diff --git a/core/context.go b/core/context.go new file mode 100644 index 0000000..f89048f --- /dev/null +++ b/core/context.go @@ -0,0 +1,55 @@ +package core + +import "context" + +// contextKey is an unexported type for context keys to prevent collisions. +// Using an unexported type ensures that only this package can create context keys, +// eliminating the risk of collisions with other packages. +type contextKey int + +const ( + claimsKey contextKey = iota +) + +// GetClaims retrieves claims from the context with type safety using generics. +// +// This is a type-safe alternative to manually type-asserting the claims from the context. +// It returns an error if the claims are not found or if the type assertion fails. +// +// Example usage: +// +// claims, err := core.GetClaims[*validator.ValidatedClaims](ctx) +// if err != nil { +// return err +// } +// // Use claims... +func GetClaims[T any](ctx context.Context) (T, error) { + var zero T + + val := ctx.Value(claimsKey) + if val == nil { + return zero, ErrClaimsNotFound + } + + claims, ok := val.(T) + if !ok { + return zero, NewValidationError( + ErrorCodeClaimsNotFound, + "claims type assertion failed", + nil, + ) + } + + return claims, nil +} + +// SetClaims stores claims in the context. +// This is a helper function for adapters to set claims after validation. +func SetClaims(ctx context.Context, claims any) context.Context { + return context.WithValue(ctx, claimsKey, claims) +} + +// HasClaims checks if claims exist in the context without retrieving them. +func HasClaims(ctx context.Context) bool { + return ctx.Value(claimsKey) != nil +} diff --git a/core/core.go b/core/core.go new file mode 100644 index 0000000..07e2d73 --- /dev/null +++ b/core/core.go @@ -0,0 +1,81 @@ +// Package core provides framework-agnostic JWT validation logic that can be used +// across different transport layers (HTTP, gRPC, etc.). +// +// The Core type encapsulates the validation logic and can be wrapped by transport-specific +// adapters to provide JWT middleware functionality for various frameworks. +package core + +import ( + "context" + "time" +) + +// TokenValidator defines the interface for JWT validation. +// Implementations should validate the token and return the validated claims. +type TokenValidator interface { + ValidateToken(ctx context.Context, token string) (any, error) +} + +// Logger defines an optional logging interface for the core middleware. +type Logger interface { + Debug(msg string, args ...any) + Info(msg string, args ...any) + Warn(msg string, args ...any) + Error(msg string, args ...any) +} + +// Core is the framework-agnostic JWT validation engine. +// It contains the core logic for token validation without any dependency +// on specific transport protocols (HTTP, gRPC, etc.). +type Core struct { + validator TokenValidator + credentialsOptional bool + logger Logger +} + +// CheckToken validates a JWT token string and returns the validated claims. +// +// This is the core validation logic that is framework-agnostic: +// - If token is empty and credentialsOptional is true, returns (nil, nil) +// - If token is empty and credentialsOptional is false, returns ErrJWTMissing +// - Otherwise, validates the token using the configured validator +// +// The returned claims (any) should be type-asserted by the caller +// to the expected claims type (typically *validator.ValidatedClaims). +func (c *Core) CheckToken(ctx context.Context, token string) (any, error) { + // Handle empty token case + if token == "" { + if c.credentialsOptional { + if c.logger != nil { + c.logger.Debug("No token provided, but credentials are optional") + } + return nil, nil + } + + if c.logger != nil { + c.logger.Warn("No token provided and credentials are required") + } + + return nil, ErrJWTMissing + } + + // Validate token + start := time.Now() + claims, err := c.validator.ValidateToken(ctx, token) + duration := time.Since(start) + + if err != nil { + if c.logger != nil { + c.logger.Error("Token validation failed", "error", err, "duration", duration) + } + + return nil, err + } + + // Success + if c.logger != nil { + c.logger.Debug("Token validated successfully", "duration", duration) + } + + return claims, nil +} diff --git a/core/core_test.go b/core/core_test.go new file mode 100644 index 0000000..1e4d858 --- /dev/null +++ b/core/core_test.go @@ -0,0 +1,367 @@ +package core + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockValidator is a mock implementation of TokenValidator for testing. +type mockValidator struct { + validateFunc func(ctx context.Context, token string) (any, error) +} + +func (m *mockValidator) ValidateToken(ctx context.Context, token string) (any, error) { + if m.validateFunc != nil { + return m.validateFunc(ctx, token) + } + return nil, errors.New("not implemented") +} + +// mockLogger is a mock implementation of Logger for testing. +type mockLogger struct { + debugCalls []logCall + infoCalls []logCall + warnCalls []logCall + errorCalls []logCall +} + +type logCall struct { + msg string + args []any +} + +func (m *mockLogger) Debug(msg string, args ...any) { + m.debugCalls = append(m.debugCalls, logCall{msg, args}) +} + +func (m *mockLogger) Info(msg string, args ...any) { + m.infoCalls = append(m.infoCalls, logCall{msg, args}) +} + +func (m *mockLogger) Warn(msg string, args ...any) { + m.warnCalls = append(m.warnCalls, logCall{msg, args}) +} + +func (m *mockLogger) Error(msg string, args ...any) { + m.errorCalls = append(m.errorCalls, logCall{msg, args}) +} + +func TestNew(t *testing.T) { + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + return "claims", nil + }, + } + + t.Run("successful creation with required options", func(t *testing.T) { + core, err := New(WithValidator(validator)) + require.NoError(t, err) + assert.NotNil(t, core) + assert.False(t, core.credentialsOptional) // Default is false + }) + + t.Run("successful creation with all options", func(t *testing.T) { + logger := &mockLogger{} + core, err := New( + WithValidator(validator), + WithCredentialsOptional(true), + WithLogger(logger), + ) + require.NoError(t, err) + assert.NotNil(t, core) + assert.True(t, core.credentialsOptional) + assert.NotNil(t, core.logger) + }) + + t.Run("error when validator is missing", func(t *testing.T) { + core, err := New() + assert.Error(t, err) + assert.Nil(t, core) + assert.Contains(t, err.Error(), "validator is required") + }) + + t.Run("error when validator is nil", func(t *testing.T) { + core, err := New(WithValidator(nil)) + assert.Error(t, err) + assert.Nil(t, core) + assert.Contains(t, err.Error(), "validator cannot be nil") + }) + + t.Run("error when logger is nil", func(t *testing.T) { + core, err := New( + WithValidator(validator), + WithLogger(nil), + ) + assert.Error(t, err) + assert.Nil(t, core) + assert.Contains(t, err.Error(), "logger cannot be nil") + }) +} + +func TestCore_CheckToken(t *testing.T) { + t.Run("successful validation", func(t *testing.T) { + expectedClaims := map[string]any{"sub": "user123"} + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + return expectedClaims, nil + }, + } + + core, err := New(WithValidator(validator)) + require.NoError(t, err) + + claims, err := core.CheckToken(context.Background(), "valid-token") + assert.NoError(t, err) + assert.Equal(t, expectedClaims, claims) + }) + + t.Run("validation error", func(t *testing.T) { + expectedErr := errors.New("invalid signature") + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + return nil, expectedErr + }, + } + + core, err := New(WithValidator(validator)) + require.NoError(t, err) + + claims, err := core.CheckToken(context.Background(), "invalid-token") + assert.Error(t, err) + assert.Nil(t, claims) + assert.Equal(t, expectedErr, err) + }) + + t.Run("empty token with credentials required", func(t *testing.T) { + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + t.Fatal("validator should not be called with empty token") + return nil, nil + }, + } + + core, err := New( + WithValidator(validator), + WithCredentialsOptional(false), // Explicit false + ) + require.NoError(t, err) + + claims, err := core.CheckToken(context.Background(), "") + assert.Error(t, err) + assert.Nil(t, claims) + assert.Equal(t, ErrJWTMissing, err) + }) + + t.Run("empty token with credentials optional", func(t *testing.T) { + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + t.Fatal("validator should not be called with empty token") + return nil, nil + }, + } + + core, err := New( + WithValidator(validator), + WithCredentialsOptional(true), + ) + require.NoError(t, err) + + claims, err := core.CheckToken(context.Background(), "") + assert.NoError(t, err) + assert.Nil(t, claims) + }) + + t.Run("logger integration on success", func(t *testing.T) { + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + return "claims", nil + }, + } + logger := &mockLogger{} + + core, err := New( + WithValidator(validator), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = core.CheckToken(context.Background(), "valid-token") + assert.NoError(t, err) + + // Should log successful validation + assert.Len(t, logger.debugCalls, 1) + assert.Contains(t, logger.debugCalls[0].msg, "validated successfully") + }) + + t.Run("logger integration on error", func(t *testing.T) { + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + return nil, errors.New("validation failed") + }, + } + logger := &mockLogger{} + + core, err := New( + WithValidator(validator), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = core.CheckToken(context.Background(), "invalid-token") + assert.Error(t, err) + + // Should log validation error + assert.Len(t, logger.errorCalls, 1) + assert.Contains(t, logger.errorCalls[0].msg, "validation failed") + }) + + t.Run("logger integration on missing token", func(t *testing.T) { + validator := &mockValidator{} + logger := &mockLogger{} + + core, err := New( + WithValidator(validator), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = core.CheckToken(context.Background(), "") + assert.Error(t, err) + + // Should log warning + assert.Len(t, logger.warnCalls, 1) + assert.Contains(t, logger.warnCalls[0].msg, "credentials are required") + }) + + t.Run("logger integration on optional credentials", func(t *testing.T) { + validator := &mockValidator{} + logger := &mockLogger{} + + core, err := New( + WithValidator(validator), + WithCredentialsOptional(true), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = core.CheckToken(context.Background(), "") + assert.NoError(t, err) + + // Should log debug message + assert.Len(t, logger.debugCalls, 1) + assert.Contains(t, logger.debugCalls[0].msg, "credentials are optional") + }) +} + +func TestCore_CheckToken_Context(t *testing.T) { + t.Run("context is passed to validator", func(t *testing.T) { + type ctxKey struct{} + expectedValue := "test-value" + ctx := context.WithValue(context.Background(), ctxKey{}, expectedValue) + + var receivedCtx context.Context + validator := &mockValidator{ + validateFunc: func(ctx context.Context, token string) (any, error) { + receivedCtx = ctx + return "claims", nil + }, + } + + core, err := New(WithValidator(validator)) + require.NoError(t, err) + + _, err = core.CheckToken(ctx, "token") + assert.NoError(t, err) + + // Verify context was passed through + assert.Equal(t, expectedValue, receivedCtx.Value(ctxKey{})) + }) +} + +func TestContextHelpers(t *testing.T) { + t.Run("SetClaims and GetClaims", func(t *testing.T) { + type testClaims struct { + Sub string + Aud string + } + + claims := &testClaims{ + Sub: "user123", + Aud: "api", + } + + ctx := context.Background() + ctx = SetClaims(ctx, claims) + + retrieved, err := GetClaims[*testClaims](ctx) + assert.NoError(t, err) + assert.Equal(t, claims, retrieved) + }) + + t.Run("GetClaims with wrong type", func(t *testing.T) { + type wrongType struct{} + + ctx := context.Background() + ctx = SetClaims(ctx, "string-claims") + + retrieved, err := GetClaims[*wrongType](ctx) + assert.Error(t, err) + assert.Nil(t, retrieved) + assert.Contains(t, err.Error(), "type assertion failed") + }) + + t.Run("GetClaims from empty context", func(t *testing.T) { + ctx := context.Background() + + claims, err := GetClaims[string](ctx) + assert.Error(t, err) + assert.Equal(t, "", claims) + assert.Equal(t, ErrClaimsNotFound, err) + }) + + t.Run("HasClaims returns true when claims exist", func(t *testing.T) { + ctx := context.Background() + ctx = SetClaims(ctx, "claims") + + assert.True(t, HasClaims(ctx)) + }) + + t.Run("HasClaims returns false when claims don't exist", func(t *testing.T) { + ctx := context.Background() + + assert.False(t, HasClaims(ctx)) + }) +} + +func TestValidationError(t *testing.T) { + t.Run("error message with details", func(t *testing.T) { + details := errors.New("signature invalid") + err := NewValidationError(ErrorCodeInvalidSignature, "token signature verification failed", details) + + assert.Contains(t, err.Error(), "token signature verification failed") + assert.Contains(t, err.Error(), "signature invalid") + }) + + t.Run("error message without details", func(t *testing.T) { + err := NewValidationError(ErrorCodeTokenMissing, "token is missing", nil) + + assert.Equal(t, "token is missing", err.Error()) + }) + + t.Run("Unwrap returns details", func(t *testing.T) { + details := errors.New("underlying error") + err := NewValidationError(ErrorCodeInvalidClaims, "validation failed", details) + + assert.Equal(t, details, errors.Unwrap(err)) + }) + + t.Run("Is works with ErrJWTInvalid", func(t *testing.T) { + err := NewValidationError(ErrorCodeInvalidSignature, "bad signature", nil) + + assert.True(t, errors.Is(err, ErrJWTInvalid)) + }) +} diff --git a/core/errors.go b/core/errors.go new file mode 100644 index 0000000..e168310 --- /dev/null +++ b/core/errors.go @@ -0,0 +1,75 @@ +package core + +import "errors" + +// Sentinel errors for JWT validation. +var ( + // ErrJWTMissing is returned when the JWT is missing from the request. + ErrJWTMissing = errors.New("jwt missing") + + // ErrJWTInvalid is returned when the JWT is invalid. + // This is typically wrapped with more specific validation errors. + ErrJWTInvalid = errors.New("jwt invalid") + + // ErrClaimsNotFound is returned when claims cannot be retrieved from context. + ErrClaimsNotFound = errors.New("claims not found in context") +) + +// ValidationError wraps JWT validation errors with additional context. +// It provides structured error information that can be used for +// logging, metrics, and returning appropriate error responses. +type ValidationError struct { + // Code is a machine-readable error code (e.g., "token_expired", "invalid_signature") + Code string + + // Message is a human-readable error message + Message string + + // Details contains the underlying error + Details error +} + +// Error implements the error interface. +func (e *ValidationError) Error() string { + if e.Details != nil { + return e.Message + ": " + e.Details.Error() + } + return e.Message +} + +// Unwrap returns the underlying error for error unwrapping. +func (e *ValidationError) Unwrap() error { + return e.Details +} + +// Is allows the error to be compared with ErrJWTInvalid. +func (e *ValidationError) Is(target error) bool { + return target == ErrJWTInvalid +} + +// Common error codes +const ( + ErrorCodeTokenMissing = "token_missing" + ErrorCodeTokenMalformed = "token_malformed" + ErrorCodeTokenExpired = "token_expired" + ErrorCodeTokenNotYetValid = "token_not_yet_valid" + ErrorCodeInvalidSignature = "invalid_signature" + ErrorCodeInvalidAlgorithm = "invalid_algorithm" + ErrorCodeInvalidIssuer = "invalid_issuer" + ErrorCodeInvalidAudience = "invalid_audience" + ErrorCodeInvalidClaims = "invalid_claims" + ErrorCodeJWKSFetchFailed = "jwks_fetch_failed" + ErrorCodeJWKSKeyNotFound = "jwks_key_not_found" + ErrorCodeConfigInvalid = "config_invalid" + ErrorCodeValidatorNotSet = "validator_not_set" + ErrorCodeClaimsNotFound = "claims_not_found" +) + +// NewValidationError creates a new ValidationError with the given code and message. +func NewValidationError(code, message string, details error) *ValidationError { + return &ValidationError{ + Code: code, + Message: message, + Details: details, + } +} diff --git a/core/option.go b/core/option.go new file mode 100644 index 0000000..7afac49 --- /dev/null +++ b/core/option.go @@ -0,0 +1,108 @@ +package core + +import ( + "errors" +) + +// Option is a function that configures the Core. +// Options return errors to enable validation during construction. +type Option func(*Core) error + +// New creates a new Core instance with the provided options. +// +// The Core must be configured with at least a TokenValidator using WithValidator. +// All other options are optional and will use sensible defaults if not provided. +// +// Example: +// +// core, err := core.New( +// core.WithValidator(validator), +// core.WithCredentialsOptional(true), +// core.WithLogger(logger), +// ) +// if err != nil { +// log.Fatal(err) +// } +func New(opts ...Option) (*Core, error) { + c := &Core{ + credentialsOptional: false, // Secure default: require credentials + } + + // Apply all options + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, err + } + } + + // Validate required configuration + if err := c.validate(); err != nil { + return nil, err + } + + return c, nil +} + +// validate ensures all required fields are set. +func (c *Core) validate() error { + if c.validator == nil { + return NewValidationError( + ErrorCodeValidatorNotSet, + "validator is required but not set (use WithValidator option)", + nil, + ) + } + return nil +} + +// WithValidator sets the token validator for the Core. +// This is a required option. +func WithValidator(validator TokenValidator) Option { + return func(c *Core) error { + if validator == nil { + return errors.New("validator cannot be nil") + } + c.validator = validator + return nil + } +} + +// WithCredentialsOptional configures whether credentials are optional. +// +// When set to true, requests without tokens will be allowed to proceed +// without validation. The claims will be nil in the context. +// +// When set to false (default), requests without tokens will return ErrJWTMissing. +// +// Use this option carefully - requiring authentication by default is more secure. +func WithCredentialsOptional(optional bool) Option { + return func(c *Core) error { + c.credentialsOptional = optional + return nil + } +} + +// WithLogger sets an optional logger for the Core. +// +// When configured, the Core will log debug information about token +// extraction, validation success/failure, and timing information. +// +// If you need custom metrics or callbacks, consider wrapping the Core +// in your own implementation that delegates to the Core for validation. +// +// Example: +// +// logger := slog.Default() +// core, _ := core.New( +// core.WithValidator(validator), +// core.WithLogger(logger), +// ) +func WithLogger(logger Logger) Option { + return func(c *Core) error { + if logger == nil { + return errors.New("logger cannot be nil") + } + c.logger = logger + return nil + } +}