Skip to content

Commit 0e95ba3

Browse files
authored
Merge pull request #133 from QuantGeekDev/feat/auth-context
Feat/auth context
2 parents d503cc5 + 182a019 commit 0e95ba3

File tree

8 files changed

+98
-25
lines changed

8 files changed

+98
-25
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "mcp-framework",
3-
"version": "0.2.16",
3+
"version": "0.2.17-beta.2",
44
"description": "Framework for building Model Context Protocol (MCP) servers in Typescript",
55
"type": "module",
66
"author": "Alex Andru <[email protected]>",

src/auth/providers/oauth.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ export class OAuthAuthProvider implements AuthProvider {
8181
logger.debug(`Token claims - sub: ${claims.sub}, scope: ${claims.scope || 'N/A'}`);
8282

8383
return {
84-
data: claims,
84+
data: { ...claims, token },
8585
};
8686
} catch (error) {
8787
if (error instanceof Error) {

src/auth/validators/jwt-validator.ts

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,11 @@ export class JWTValidator {
112112
};
113113

114114
// Only validate audience if not set to wildcard
115-
if (this.config.audience !== '*') {
116-
options.audience = this.config.audience;
117-
}
115+
// For Cognito Access Tokens, 'aud' is missing but 'client_id' is present.
116+
// We disable the library's strict check and handle it manually in the callback.
117+
// if (this.config.audience !== '*') {
118+
// options.audience = this.config.audience;
119+
// }
118120

119121
jwt.verify(token, publicKey, options, (err, decoded) => {
120122
if (err) {
@@ -151,10 +153,32 @@ export class JWTValidator {
151153
return;
152154
}
153155

154-
// Only require aud claim if not set to wildcard
155-
if (this.config.audience !== '*' && !claims.aud) {
156-
reject(new Error('Token missing required claim: aud'));
157-
return;
156+
// Only require aud/client_id claim if not set to wildcard
157+
if (this.config.audience !== '*') {
158+
const aud = claims.aud;
159+
const clientId = claims.client_id as string | undefined;
160+
const expectedAudience = this.config.audience;
161+
162+
let isValidAudience = false;
163+
164+
// Check 'aud' claim (ID Tokens)
165+
if (aud) {
166+
if (Array.isArray(aud)) {
167+
if (aud.includes(expectedAudience)) isValidAudience = true;
168+
} else {
169+
if (aud === expectedAudience) isValidAudience = true;
170+
}
171+
}
172+
173+
// Check 'client_id' claim (Access Tokens)
174+
if (!isValidAudience && clientId) {
175+
if (clientId === expectedAudience) isValidAudience = true;
176+
}
177+
178+
if (!isValidAudience) {
179+
reject(new Error(`Token audience mismatch. Expected ${expectedAudience}, got aud: ${aud}, client_id: ${clientId}`));
180+
return;
181+
}
158182
}
159183

160184
if (!claims.exp) {

src/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ export * from './auth/index.js';
1010
export type { SSETransportConfig } from './transports/sse/types.js';
1111
export type { HttpStreamTransportConfig } from './transports/http/types.js';
1212
export { HttpStreamTransport } from './transports/http/server.js';
13+
14+
export { requestContext, getRequestContext, runInRequestContext } from './utils/requestContext.js';
15+
export type { RequestContextData } from './utils/requestContext.js';

src/transports/http/server.ts

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import { logger } from '../../core/Logger.js';
88
import { ProtectedResourceMetadata } from '../../auth/metadata/protected-resource.js';
99
import { handleAuthentication } from '../utils/auth-handler.js';
1010
import { initializeOAuthMetadata } from '../utils/oauth-metadata.js';
11+
import { requestContext, RequestContextData } from '../../utils/requestContext.js';
12+
import { AuthResult } from '../../auth/types.js';
1113

1214
export class HttpStreamTransport extends AbstractTransport {
1315
readonly type = 'http-stream';
@@ -120,14 +122,17 @@ export class HttpStreamTransport extends AbstractTransport {
120122

121123
// Perform authentication check once at the beginning
122124
const authEndpoint = isInitialize ? 'sse' : 'messages';
125+
let authData: RequestContextData = {};
126+
123127
if (this._config.auth?.endpoints?.[authEndpoint] !== false) {
124-
const isAuthenticated = await handleAuthentication(
128+
const authResult = await handleAuthentication(
125129
req,
126130
res,
127131
this._config.auth,
128132
isInitialize ? 'initialize' : 'message'
129133
);
130-
if (!isAuthenticated) return;
134+
if (!authResult) return;
135+
authData = (authResult as AuthResult).data as RequestContextData || {};
131136
}
132137

133138
// Handle different request scenarios
@@ -168,7 +173,9 @@ export class HttpStreamTransport extends AbstractTransport {
168173
}
169174
};
170175

171-
await transport.handleRequest(req, res, body);
176+
await requestContext.run(authData, async () => {
177+
await transport.handleRequest(req, res, body);
178+
});
172179
return;
173180
} else if (!sessionId) {
174181
// No session ID and not an initialize request
@@ -181,7 +188,9 @@ export class HttpStreamTransport extends AbstractTransport {
181188
}
182189

183190
// Existing session - handle request
184-
await transport.handleRequest(req, res, body);
191+
await requestContext.run(authData, async () => {
192+
await transport.handleRequest(req, res, body);
193+
});
185194
}
186195

187196
private async readRequestBody(req: IncomingMessage): Promise<any> {

src/transports/sse/server.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import { PING_SSE_MESSAGE } from "../utils/ping-message.js";
1111
import { ProtectedResourceMetadata } from "../../auth/metadata/protected-resource.js";
1212
import { handleAuthentication } from "../utils/auth-handler.js";
1313
import { initializeOAuthMetadata } from "../utils/oauth-metadata.js";
14+
import { requestContext, RequestContextData } from "../../utils/requestContext.js";
15+
import { AuthResult } from "../../auth/types.js";
1416

1517
interface ExtendedIncomingMessage extends IncomingMessage {
1618
body?: ClientRequest;
@@ -167,9 +169,12 @@ export class SSEServerTransport extends AbstractTransport {
167169
}
168170

169171
if (req.method === "POST" && url.pathname === this._config.messageEndpoint) {
172+
let authData: RequestContextData = {};
173+
170174
if (this._config.auth?.endpoints?.messages !== false) {
171-
const isAuthenticated = await handleAuthentication(req, res, this._config.auth, "message")
172-
if (!isAuthenticated) return
175+
const authResult = await handleAuthentication(req, res, this._config.auth, "message")
176+
if (!authResult) return
177+
authData = (authResult as AuthResult).data as RequestContextData || {};
173178
}
174179

175180
// **Connection Validation (User Requested):**
@@ -183,7 +188,7 @@ export class SSEServerTransport extends AbstractTransport {
183188
return;
184189
}
185190

186-
await this.handlePostMessage(req, res)
191+
await this.handlePostMessage(req, res, authData)
187192
return
188193
}
189194

@@ -250,7 +255,7 @@ export class SSEServerTransport extends AbstractTransport {
250255
logger.info(`SSE connection established successfully: ${connectionId}`);
251256
}
252257

253-
private async handlePostMessage(req: IncomingMessage, res: ServerResponse): Promise<void> {
258+
private async handlePostMessage(req: IncomingMessage, res: ServerResponse, authData: RequestContextData = {}): Promise<void> {
254259
// Check if *any* connection is active, not just the old single _sseResponse
255260
if (this._connections.size === 0) {
256261
logger.warn(`Rejecting message: no active SSE connections for server session ${this._sessionId}`);
@@ -301,7 +306,9 @@ export class SSEServerTransport extends AbstractTransport {
301306
throw new Error("No message handler registered")
302307
}
303308

304-
await this._onmessage(rpcMessage)
309+
await requestContext.run(authData, async () => {
310+
await this._onmessage!(rpcMessage)
311+
})
305312

306313
res.writeHead(202).end("Accepted")
307314

src/transports/utils/auth-handler.ts

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { IncomingMessage, ServerResponse } from 'node:http';
2-
import { AuthConfig } from '../../auth/types.js';
2+
import { AuthConfig, AuthResult } from '../../auth/types.js';
33
import { APIKeyAuthProvider } from '../../auth/providers/apikey.js';
44
import { OAuthAuthProvider } from '../../auth/providers/oauth.js';
55
import { DEFAULT_AUTH_ERROR } from '../../auth/types.js';
@@ -14,16 +14,16 @@ import { logger } from '../../core/Logger.js';
1414
* @param res - HTTP response object
1515
* @param authConfig - Authentication configuration from transport
1616
* @param context - Description of the context (e.g., "initialize", "message", "SSE connection")
17-
* @returns True if authenticated, false if authentication failed (response already sent)
17+
* @returns AuthResult if authenticated, null if authentication failed (response already sent)
1818
*/
1919
export async function handleAuthentication(
2020
req: IncomingMessage,
2121
res: ServerResponse,
2222
authConfig: AuthConfig | undefined,
2323
context: string
24-
): Promise<boolean> {
24+
): Promise<AuthResult | null> {
2525
if (!authConfig?.provider) {
26-
return true;
26+
return { data: {} };
2727
}
2828

2929
const isApiKey = authConfig.provider instanceof APIKeyAuthProvider;
@@ -43,7 +43,7 @@ export async function handleAuthentication(
4343
type: 'authentication_error',
4444
})
4545
);
46-
return false;
46+
return null;
4747
}
4848
}
4949

@@ -72,12 +72,16 @@ export async function handleAuthentication(
7272
type: 'authentication_error',
7373
})
7474
);
75-
return false;
75+
return null;
7676
}
7777

7878
// Authentication successful
7979
logger.info(`Authentication successful for ${context}:`);
8080
logger.info(`- Client IP: ${req.socket.remoteAddress}`);
8181
logger.info(`- Auth Type: ${authConfig.provider.constructor.name}`);
82-
return true;
82+
83+
if (typeof authResult === 'boolean') {
84+
return { data: {} };
85+
}
86+
return authResult;
8387
}

src/utils/requestContext.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import { AsyncLocalStorage } from 'async_hooks';
2+
3+
export interface RequestContextData {
4+
token?: string; // The raw token
5+
user?: Record<string, unknown>; // Decoded user data/claims
6+
[key: string]: unknown;
7+
}
8+
9+
export const requestContext = new AsyncLocalStorage<RequestContextData>();
10+
11+
/**
12+
* Get the current request context.
13+
* Returns undefined if called outside of a request context.
14+
*/
15+
export function getRequestContext(): RequestContextData | undefined {
16+
return requestContext.getStore();
17+
}
18+
19+
/**
20+
* Run a function within a request context.
21+
*/
22+
export function runInRequestContext<T>(context: RequestContextData, fn: () => T): T {
23+
return requestContext.run(context, fn);
24+
}
25+
26+

0 commit comments

Comments
 (0)