Skip to content

Commit e9ddaa4

Browse files
Merge branch 'v3-phase1-pr5-middleware-options' into v3-phase1-pr6-documentation-linting
2 parents b67ad41 + c7ca941 commit e9ddaa4

File tree

2 files changed

+7
-26
lines changed

2 files changed

+7
-26
lines changed

middleware.go

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,11 @@ import (
88
"github.com/auth0/go-jwt-middleware/v3/core"
99
)
1010

11-
// contextKey is an unexported type for context keys to prevent collisions.
12-
// Only this package can create contextKey values, following Go best practices.
13-
type contextKey int
14-
15-
const (
16-
// claimsContextKey is the key for storing validated JWT claims in the request context.
17-
claimsContextKey contextKey = iota
18-
)
19-
2011
// JWTMiddleware is a middleware that validates JWTs and makes claims available in the request context.
2112
// It wraps the core validation engine and provides HTTP-specific functionality like token extraction
2213
// and error handling.
14+
//
15+
// Claims are stored in the context using core.SetClaims() and can be retrieved using core.GetClaims[T]().
2316
type JWTMiddleware struct {
2417
core *core.Core
2518
errorHandler ErrorHandler
@@ -148,19 +141,7 @@ func (m *JWTMiddleware) applyDefaults() {
148141
// }
149142
// fmt.Println(claims.RegisteredClaims.Subject)
150143
func GetClaims[T any](ctx context.Context) (T, error) {
151-
var zero T
152-
153-
val := ctx.Value(claimsContextKey)
154-
if val == nil {
155-
return zero, fmt.Errorf("claims not found in context")
156-
}
157-
158-
claims, ok := val.(T)
159-
if !ok {
160-
return zero, fmt.Errorf("claims have wrong type: expected %T, got %T", zero, val)
161-
}
162-
163-
return claims, nil
144+
return core.GetClaims[T](ctx)
164145
}
165146

166147
// MustGetClaims retrieves claims from the context or panics.
@@ -171,7 +152,7 @@ func GetClaims[T any](ctx context.Context) (T, error) {
171152
// claims := jwtmiddleware.MustGetClaims[*validator.ValidatedClaims](r.Context())
172153
// fmt.Println(claims.RegisteredClaims.Subject)
173154
func MustGetClaims[T any](ctx context.Context) T {
174-
claims, err := GetClaims[T](ctx)
155+
claims, err := core.GetClaims[T](ctx)
175156
if err != nil {
176157
panic(err)
177158
}
@@ -187,7 +168,7 @@ func MustGetClaims[T any](ctx context.Context) T {
187168
// // Use claims...
188169
// }
189170
func HasClaims(ctx context.Context) bool {
190-
return ctx.Value(claimsContextKey) != nil
171+
return core.HasClaims(ctx)
191172
}
192173

193174
// CheckJWT is the main JWTMiddleware function which performs the main logic. It
@@ -267,7 +248,7 @@ func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler {
267248
if m.logger != nil {
268249
m.logger.Debug("JWT validation successful, setting claims in context")
269250
}
270-
r = r.Clone(context.WithValue(r.Context(), claimsContextKey, validToken))
251+
r = r.Clone(core.SetClaims(r.Context(), validToken))
271252
next.ServeHTTP(w, r)
272253
})
273254
}

option_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ func Test_GetClaims(t *testing.T) {
591591
return createContextWithClaims(wrongClaims)
592592
},
593593
wantErr: true,
594-
errMsg: "claims have wrong type",
594+
errMsg: "claims type assertion failed",
595595
},
596596
}
597597

0 commit comments

Comments
 (0)