Skip to content

Commit eeb74bd

Browse files
committed
Merge remote-tracking branch 'upstream/main' into feat/sep-985
2 parents b32c4c6 + 8cdac3d commit eeb74bd

File tree

25 files changed

+628
-218
lines changed

25 files changed

+628
-218
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
[protocol-badge]: https://img.shields.io/badge/protocol-modelcontextprotocol.io-blue.svg
8080
[protocol-url]: https://modelcontextprotocol.io
8181
[spec-badge]: https://img.shields.io/badge/spec-spec.modelcontextprotocol.io-blue.svg
82-
[spec-url]: https://spec.modelcontextprotocol.io
82+
[spec-url]: https://modelcontextprotocol.io/specification/latest
8383

8484
## Overview
8585

@@ -2433,7 +2433,7 @@ MCP servers declare capabilities during initialization:
24332433

24342434
- [API Reference](https://modelcontextprotocol.github.io/python-sdk/api/)
24352435
- [Model Context Protocol documentation](https://modelcontextprotocol.io)
2436-
- [Model Context Protocol specification](https://spec.modelcontextprotocol.io)
2436+
- [Model Context Protocol specification](https://modelcontextprotocol.io/specification/latest)
24372437
- [Officially supported servers](https://github.com/modelcontextprotocol/servers)
24382438

24392439
## Contributing

examples/clients/simple-auth-client/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ ignore = []
3939
line-length = 120
4040
target-version = "py310"
4141

42-
[tool.uv]
43-
dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"]
42+
[dependency-groups]
43+
dev = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"]

examples/clients/simple-chatbot/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ ignore = []
4444
line-length = 120
4545
target-version = "py310"
4646

47-
[tool.uv]
48-
dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"]
47+
[dependency-groups]
48+
dev = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"]

examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
7373

7474
async def register_client(self, client_info: OAuthClientInformationFull):
7575
"""Register a new OAuth client."""
76+
if not client_info.client_id:
77+
raise ValueError("No client_id provided")
7678
self.clients[client_info.client_id] = client_info
7779

7880
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
@@ -209,6 +211,8 @@ async def exchange_authorization_code(
209211
"""Exchange authorization code for tokens."""
210212
if authorization_code.code not in self.auth_codes:
211213
raise ValueError("Invalid authorization code")
214+
if not client.client_id:
215+
raise ValueError("No client_id provided")
212216

213217
# Generate MCP access token
214218
mcp_token = f"mcp_{secrets.token_hex(32)}"

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"uvicorn>=0.31.1; sys_platform != 'emscripten'",
3434
"jsonschema>=4.20.0",
3535
"pywin32>=310; sys_platform == 'win32'",
36+
"pyjwt[crypto]>=2.10.1",
3637
]
3738

3839
[project.optional-dependencies]
@@ -98,7 +99,7 @@ venv = ".venv"
9899
# those private functions instead of testing the private functions directly. It makes it easier to maintain the code source
99100
# and refactor code that is not public.
100101
executionEnvironments = [
101-
{ root = "tests", reportUnusedFunction = false, reportPrivateUsage = false },
102+
{ root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false },
102103
{ root = "examples/servers", reportUnusedFunction = false },
103104
]
104105

src/mcp/client/auth/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
OAuth2 Authentication implementation for HTTPX.
3+
4+
Implements authorization code flow with PKCE and automatic token refresh.
5+
"""
6+
7+
from mcp.client.auth.oauth2 import (
8+
OAuthClientProvider,
9+
OAuthFlowError,
10+
OAuthRegistrationError,
11+
OAuthTokenError,
12+
PKCEParameters,
13+
TokenStorage,
14+
)
15+
16+
__all__ = [
17+
"OAuthClientProvider",
18+
"OAuthFlowError",
19+
"OAuthRegistrationError",
20+
"OAuthTokenError",
21+
"PKCEParameters",
22+
"TokenStorage",
23+
]

src/mcp/client/auth/extensions/__init__.py

Whitespace-only changes.
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import time
2+
from collections.abc import Awaitable, Callable
3+
from typing import Any
4+
from uuid import uuid4
5+
6+
import httpx
7+
import jwt
8+
from pydantic import BaseModel, Field
9+
10+
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
11+
from mcp.shared.auth import OAuthClientMetadata
12+
13+
14+
class JWTParameters(BaseModel):
15+
"""JWT parameters."""
16+
17+
assertion: str | None = Field(
18+
default=None,
19+
description="JWT assertion for JWT authentication. "
20+
"Will be used instead of generating a new assertion if provided.",
21+
)
22+
23+
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
24+
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
25+
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
26+
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
27+
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
28+
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
29+
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
30+
31+
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
32+
if self.assertion is not None:
33+
# Prebuilt JWT (e.g. acquired out-of-band)
34+
assertion = self.assertion
35+
else:
36+
if not self.jwt_signing_key:
37+
raise OAuthFlowError("Missing signing key for JWT bearer grant")
38+
if not self.issuer:
39+
raise OAuthFlowError("Missing issuer for JWT bearer grant")
40+
if not self.subject:
41+
raise OAuthFlowError("Missing subject for JWT bearer grant")
42+
43+
audience = self.audience if self.audience else with_audience_fallback
44+
if not audience:
45+
raise OAuthFlowError("Missing audience for JWT bearer grant")
46+
47+
now = int(time.time())
48+
claims: dict[str, Any] = {
49+
"iss": self.issuer,
50+
"sub": self.subject,
51+
"aud": audience,
52+
"exp": now + self.jwt_lifetime_seconds,
53+
"iat": now,
54+
"jti": str(uuid4()),
55+
}
56+
claims.update(self.claims or {})
57+
58+
assertion = jwt.encode(
59+
claims,
60+
self.jwt_signing_key,
61+
algorithm=self.jwt_signing_algorithm or "RS256",
62+
)
63+
return assertion
64+
65+
66+
class RFC7523OAuthClientProvider(OAuthClientProvider):
67+
"""OAuth client provider for RFC7532 clients."""
68+
69+
jwt_parameters: JWTParameters | None = None
70+
71+
def __init__(
72+
self,
73+
server_url: str,
74+
client_metadata: OAuthClientMetadata,
75+
storage: TokenStorage,
76+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
77+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
78+
timeout: float = 300.0,
79+
jwt_parameters: JWTParameters | None = None,
80+
) -> None:
81+
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
82+
self.jwt_parameters = jwt_parameters
83+
84+
async def _exchange_token_authorization_code(
85+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
86+
) -> httpx.Request:
87+
"""Build token exchange request for authorization_code flow."""
88+
token_data = token_data or {}
89+
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
90+
self._add_client_authentication_jwt(token_data=token_data)
91+
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
92+
93+
async def _perform_authorization(self) -> httpx.Request:
94+
"""Perform the authorization flow."""
95+
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
96+
token_request = await self._exchange_token_jwt_bearer()
97+
return token_request
98+
else:
99+
return await super()._perform_authorization()
100+
101+
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
102+
"""Add JWT assertion for client authentication to token endpoint parameters."""
103+
if not self.jwt_parameters:
104+
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
105+
if not self.context.oauth_metadata:
106+
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
107+
108+
# We need to set the audience to the issuer identifier of the authorization server
109+
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
110+
issuer = str(self.context.oauth_metadata.issuer)
111+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
112+
113+
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
114+
token_data["client_assertion"] = assertion
115+
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
116+
# We need to set the audience to the resource server, the audience is difference from the one in claims
117+
# it represents the resource server that will validate the token
118+
token_data["audience"] = self.context.get_resource_url()
119+
120+
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
121+
"""Build token exchange request for JWT bearer grant."""
122+
if not self.context.client_info:
123+
raise OAuthFlowError("Missing client info")
124+
if not self.jwt_parameters:
125+
raise OAuthFlowError("Missing JWT parameters")
126+
if not self.context.oauth_metadata:
127+
raise OAuthTokenError("Missing OAuth metadata")
128+
129+
# We need to set the audience to the issuer identifier of the authorization server
130+
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
131+
issuer = str(self.context.oauth_metadata.issuer)
132+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
133+
134+
token_data = {
135+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
136+
"assertion": assertion,
137+
}
138+
139+
if self.context.should_include_resource_param(self.context.protocol_version):
140+
token_data["resource"] = self.context.get_resource_url()
141+
142+
if self.context.client_metadata.scope:
143+
token_data["scope"] = self.context.client_metadata.scope
144+
145+
token_url = self._get_token_endpoint()
146+
return httpx.Request(
147+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
148+
)

src/mcp/client/auth.py renamed to src/mcp/client/auth/oauth2.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
1515
from dataclasses import dataclass, field
16-
from typing import Protocol
16+
from typing import Any, Protocol
1717
from urllib.parse import urlencode, urljoin, urlparse
1818

1919
import anyio
@@ -88,8 +88,8 @@ class OAuthContext:
8888
server_url: str
8989
client_metadata: OAuthClientMetadata
9090
storage: TokenStorage
91-
redirect_handler: Callable[[str], Awaitable[None]]
92-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]]
91+
redirect_handler: Callable[[str], Awaitable[None]] | None
92+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
9393
timeout: float = 300.0
9494

9595
# Discovered metadata
@@ -194,8 +194,8 @@ def __init__(
194194
server_url: str,
195195
client_metadata: OAuthClientMetadata,
196196
storage: TokenStorage,
197-
redirect_handler: Callable[[str], Awaitable[None]],
198-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]],
197+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
198+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
199199
timeout: float = 300.0,
200200
):
201201
"""Initialize OAuth2 authentication."""
@@ -423,8 +423,21 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
423423
except ValidationError as e:
424424
raise OAuthRegistrationError(f"Invalid registration response: {e}")
425425

426-
async def _perform_authorization(self) -> tuple[str, str]:
426+
async def _perform_authorization(self) -> httpx.Request:
427+
"""Perform the authorization flow."""
428+
auth_code, code_verifier = await self._perform_authorization_code_grant()
429+
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
430+
return token_request
431+
432+
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
427433
"""Perform the authorization redirect and get auth code."""
434+
if self.context.client_metadata.redirect_uris is None:
435+
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
436+
if not self.context.redirect_handler:
437+
raise OAuthFlowError("No redirect handler provided for authorization code grant")
438+
if not self.context.callback_handler:
439+
raise OAuthFlowError("No callback handler provided for authorization code grant")
440+
428441
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
429442
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)
430443
else:
@@ -469,24 +482,34 @@ async def _perform_authorization(self) -> tuple[str, str]:
469482
# Return auth code and code verifier for token exchange
470483
return auth_code, pkce_params.code_verifier
471484

472-
async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request:
473-
"""Build token exchange request."""
474-
if not self.context.client_info:
475-
raise OAuthFlowError("Missing client info")
476-
485+
def _get_token_endpoint(self) -> str:
477486
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
478487
token_url = str(self.context.oauth_metadata.token_endpoint)
479488
else:
480489
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
481490
token_url = urljoin(auth_base_url, "/token")
491+
return token_url
492+
493+
async def _exchange_token_authorization_code(
494+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
495+
) -> httpx.Request:
496+
"""Build token exchange request for authorization_code flow."""
497+
if self.context.client_metadata.redirect_uris is None:
498+
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
499+
if not self.context.client_info:
500+
raise OAuthFlowError("Missing client info")
482501

483-
token_data = {
484-
"grant_type": "authorization_code",
485-
"code": auth_code,
486-
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
487-
"client_id": self.context.client_info.client_id,
488-
"code_verifier": code_verifier,
489-
}
502+
token_url = self._get_token_endpoint()
503+
token_data = token_data or {}
504+
token_data.update(
505+
{
506+
"grant_type": "authorization_code",
507+
"code": auth_code,
508+
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
509+
"client_id": self.context.client_info.client_id,
510+
"code_verifier": code_verifier,
511+
}
512+
)
490513

491514
# Only include resource param if conditions are met
492515
if self.context.should_include_resource_param(self.context.protocol_version):
@@ -502,7 +525,9 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
502525
async def _handle_token_response(self, response: httpx.Response) -> None:
503526
"""Handle token exchange response."""
504527
if response.status_code != 200:
505-
raise OAuthTokenError(f"Token exchange failed: {response.status_code}")
528+
body = await response.aread()
529+
body = body.decode("utf-8")
530+
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
506531

507532
try:
508533
content = await response.aread()
@@ -663,12 +688,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
663688
registration_response = yield registration_request
664689
await self._handle_registration_response(registration_response)
665690

666-
# Step 5: Perform authorization
667-
auth_code, code_verifier = await self._perform_authorization()
668-
669-
# Step 6: Exchange authorization code for tokens
670-
token_request = await self._exchange_token(auth_code, code_verifier)
671-
token_response = yield token_request
691+
# Step 5: Perform authorization and complete token exchange
692+
token_response = yield await self._perform_authorization()
672693
await self._handle_token_response(token_response)
673694
except Exception:
674695
logger.exception("OAuth flow error")
@@ -687,17 +708,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
687708
# Step 2a: Update the required scopes
688709
self._select_scopes(response)
689710

690-
# Step 2b: Perform (re-)authorization
691-
auth_code, code_verifier = await self._perform_authorization()
692-
693-
# Step 2c: Exchange authorization code for tokens
694-
token_request = await self._exchange_token(auth_code, code_verifier)
695-
token_response = yield token_request
711+
# Step 2b: Perform (re-)authorization and token exchange
712+
token_response = yield await self._perform_authorization()
696713
await self._handle_token_response(token_response)
697714
except Exception:
698715
logger.exception("OAuth flow error")
699716
raise
700717

701-
# Retry with new tokens
702-
self._add_auth_header(request)
703-
yield request
718+
# Retry with new tokens
719+
self._add_auth_header(request)
720+
yield request

0 commit comments

Comments
 (0)