@@ -8,18 +8,11 @@ import (
88 "github.com/auth0/go-jwt-middleware/v3/core"
99)
1010
11- // contextKey is an unexported type for context keys to prevent collisions.
12- // Only this package can create contextKey values, following Go best practices.
13- type contextKey int
14-
15- const (
16- // claimsContextKey is the key for storing validated JWT claims in the request context.
17- claimsContextKey contextKey = iota
18- )
19-
2011// JWTMiddleware is a middleware that validates JWTs and makes claims available in the request context.
2112// It wraps the core validation engine and provides HTTP-specific functionality like token extraction
2213// and error handling.
14+ //
15+ // Claims are stored in the context using core.SetClaims() and can be retrieved using core.GetClaims[T]().
2316type JWTMiddleware struct {
2417 core * core.Core
2518 errorHandler ErrorHandler
@@ -148,19 +141,7 @@ func (m *JWTMiddleware) applyDefaults() {
148141// }
149142// fmt.Println(claims.RegisteredClaims.Subject)
150143func GetClaims [T any ](ctx context.Context ) (T , error ) {
151- var zero T
152-
153- val := ctx .Value (claimsContextKey )
154- if val == nil {
155- return zero , fmt .Errorf ("claims not found in context" )
156- }
157-
158- claims , ok := val .(T )
159- if ! ok {
160- return zero , fmt .Errorf ("claims have wrong type: expected %T, got %T" , zero , val )
161- }
162-
163- return claims , nil
144+ return core.GetClaims [T ](ctx )
164145}
165146
166147// MustGetClaims retrieves claims from the context or panics.
@@ -171,7 +152,7 @@ func GetClaims[T any](ctx context.Context) (T, error) {
171152// claims := jwtmiddleware.MustGetClaims[*validator.ValidatedClaims](r.Context())
172153// fmt.Println(claims.RegisteredClaims.Subject)
173154func MustGetClaims [T any ](ctx context.Context ) T {
174- claims , err := GetClaims [T ](ctx )
155+ claims , err := core. GetClaims [T ](ctx )
175156 if err != nil {
176157 panic (err )
177158 }
@@ -187,7 +168,7 @@ func MustGetClaims[T any](ctx context.Context) T {
187168// // Use claims...
188169// }
189170func HasClaims (ctx context.Context ) bool {
190- return ctx . Value ( claimsContextKey ) != nil
171+ return core . HasClaims ( ctx )
191172}
192173
193174// CheckJWT is the main JWTMiddleware function which performs the main logic. It
@@ -267,7 +248,7 @@ func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler {
267248 if m .logger != nil {
268249 m .logger .Debug ("JWT validation successful, setting claims in context" )
269250 }
270- r = r .Clone (context . WithValue (r .Context (), claimsContextKey , validToken ))
251+ r = r .Clone (core . SetClaims (r .Context (), validToken ))
271252 next .ServeHTTP (w , r )
272253 })
273254}
0 commit comments