Skip to content

Commit fca00d1

Browse files
authored
Merge branch 'main' into feat/instructions-docs
2 parents d079d2c + 202af49 commit fca00d1

File tree

23 files changed

+559
-210
lines changed

23 files changed

+559
-210
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ async def _default_redirect_handler(authorization_url: str) -> None:
187187

188188
# Create OAuth authentication handler using the new interface
189189
oauth_auth = OAuthClientProvider(
190-
server_url=self.server_url.replace("/mcp", ""),
190+
server_url=self.server_url,
191191
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
192192
storage=InMemoryTokenStorage(),
193193
redirect_handler=_default_redirect_handler,

examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
7373

7474
async def register_client(self, client_info: OAuthClientInformationFull):
7575
"""Register a new OAuth client."""
76+
if not client_info.client_id:
77+
raise ValueError("No client_id provided")
7678
self.clients[client_info.client_id] = client_info
7779

7880
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
@@ -209,6 +211,8 @@ async def exchange_authorization_code(
209211
"""Exchange authorization code for tokens."""
210212
if authorization_code.code not in self.auth_codes:
211213
raise ValueError("Invalid authorization code")
214+
if not client.client_id:
215+
raise ValueError("No client_id provided")
212216

213217
# Generate MCP access token
214218
mcp_token = f"mcp_{secrets.token_hex(32)}"

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"uvicorn>=0.31.1; sys_platform != 'emscripten'",
3434
"jsonschema>=4.20.0",
3535
"pywin32>=310; sys_platform == 'win32'",
36+
"pyjwt[crypto]>=2.10.1",
3637
]
3738

3839
[project.optional-dependencies]
@@ -98,7 +99,7 @@ venv = ".venv"
9899
# those private functions instead of testing the private functions directly. It makes it easier to maintain the code source
99100
# and refactor code that is not public.
100101
executionEnvironments = [
101-
{ root = "tests", reportUnusedFunction = false, reportPrivateUsage = false },
102+
{ root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false },
102103
{ root = "examples/servers", reportUnusedFunction = false },
103104
]
104105

src/mcp/cli/cli.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _build_uv_command(
6767
with_editable: Path | None = None,
6868
with_packages: list[str] | None = None,
6969
) -> list[str]:
70-
"""Build the uv run command that runs a MCP server through mcp run."""
70+
"""Build the uv run command that runs an MCP server through mcp run."""
7171
cmd = ["uv"]
7272

7373
cmd.extend(["run", "--with", "mcp"])
@@ -117,7 +117,7 @@ def _parse_file_path(file_spec: str) -> tuple[Path, str | None]:
117117

118118

119119
def _import_server(file: Path, server_object: str | None = None):
120-
"""Import a MCP server from a file.
120+
"""Import an MCP server from a file.
121121
122122
Args:
123123
file: Path to the file
@@ -244,7 +244,7 @@ def dev(
244244
),
245245
] = [],
246246
) -> None:
247-
"""Run a MCP server with the MCP Inspector."""
247+
"""Run an MCP server with the MCP Inspector."""
248248
file, server_object = _parse_file_path(file_spec)
249249

250250
logger.debug(
@@ -317,7 +317,7 @@ def run(
317317
),
318318
] = None,
319319
) -> None:
320-
"""Run a MCP server.
320+
"""Run an MCP server.
321321
322322
The server can be specified in two ways:\n
323323
1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n
@@ -412,7 +412,7 @@ def install(
412412
),
413413
] = None,
414414
) -> None:
415-
"""Install a MCP server in the Claude desktop app.
415+
"""Install an MCP server in the Claude desktop app.
416416
417417
Environment variables are preserved once added and only updated if new values
418418
are explicitly provided.

src/mcp/client/auth/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
OAuth2 Authentication implementation for HTTPX.
3+
4+
Implements authorization code flow with PKCE and automatic token refresh.
5+
"""
6+
7+
from mcp.client.auth.oauth2 import (
8+
OAuthClientProvider,
9+
OAuthFlowError,
10+
OAuthRegistrationError,
11+
OAuthTokenError,
12+
PKCEParameters,
13+
TokenStorage,
14+
)
15+
16+
__all__ = [
17+
"OAuthClientProvider",
18+
"OAuthFlowError",
19+
"OAuthRegistrationError",
20+
"OAuthTokenError",
21+
"PKCEParameters",
22+
"TokenStorage",
23+
]

src/mcp/client/auth/extensions/__init__.py

Whitespace-only changes.
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import time
2+
from collections.abc import Awaitable, Callable
3+
from typing import Any
4+
from uuid import uuid4
5+
6+
import httpx
7+
import jwt
8+
from pydantic import BaseModel, Field
9+
10+
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
11+
from mcp.shared.auth import OAuthClientMetadata
12+
13+
14+
class JWTParameters(BaseModel):
15+
"""JWT parameters."""
16+
17+
assertion: str | None = Field(
18+
default=None,
19+
description="JWT assertion for JWT authentication. "
20+
"Will be used instead of generating a new assertion if provided.",
21+
)
22+
23+
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
24+
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
25+
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
26+
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
27+
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
28+
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
29+
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
30+
31+
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
32+
if self.assertion is not None:
33+
# Prebuilt JWT (e.g. acquired out-of-band)
34+
assertion = self.assertion
35+
else:
36+
if not self.jwt_signing_key:
37+
raise OAuthFlowError("Missing signing key for JWT bearer grant")
38+
if not self.issuer:
39+
raise OAuthFlowError("Missing issuer for JWT bearer grant")
40+
if not self.subject:
41+
raise OAuthFlowError("Missing subject for JWT bearer grant")
42+
43+
audience = self.audience if self.audience else with_audience_fallback
44+
if not audience:
45+
raise OAuthFlowError("Missing audience for JWT bearer grant")
46+
47+
now = int(time.time())
48+
claims: dict[str, Any] = {
49+
"iss": self.issuer,
50+
"sub": self.subject,
51+
"aud": audience,
52+
"exp": now + self.jwt_lifetime_seconds,
53+
"iat": now,
54+
"jti": str(uuid4()),
55+
}
56+
claims.update(self.claims or {})
57+
58+
assertion = jwt.encode(
59+
claims,
60+
self.jwt_signing_key,
61+
algorithm=self.jwt_signing_algorithm or "RS256",
62+
)
63+
return assertion
64+
65+
66+
class RFC7523OAuthClientProvider(OAuthClientProvider):
67+
"""OAuth client provider for RFC7532 clients."""
68+
69+
jwt_parameters: JWTParameters | None = None
70+
71+
def __init__(
72+
self,
73+
server_url: str,
74+
client_metadata: OAuthClientMetadata,
75+
storage: TokenStorage,
76+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
77+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
78+
timeout: float = 300.0,
79+
jwt_parameters: JWTParameters | None = None,
80+
) -> None:
81+
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
82+
self.jwt_parameters = jwt_parameters
83+
84+
async def _exchange_token_authorization_code(
85+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
86+
) -> httpx.Request:
87+
"""Build token exchange request for authorization_code flow."""
88+
token_data = token_data or {}
89+
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
90+
self._add_client_authentication_jwt(token_data=token_data)
91+
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
92+
93+
async def _perform_authorization(self) -> httpx.Request:
94+
"""Perform the authorization flow."""
95+
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
96+
token_request = await self._exchange_token_jwt_bearer()
97+
return token_request
98+
else:
99+
return await super()._perform_authorization()
100+
101+
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
102+
"""Add JWT assertion for client authentication to token endpoint parameters."""
103+
if not self.jwt_parameters:
104+
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
105+
if not self.context.oauth_metadata:
106+
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
107+
108+
# We need to set the audience to the issuer identifier of the authorization server
109+
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
110+
issuer = str(self.context.oauth_metadata.issuer)
111+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
112+
113+
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
114+
token_data["client_assertion"] = assertion
115+
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
116+
# We need to set the audience to the resource server, the audience is difference from the one in claims
117+
# it represents the resource server that will validate the token
118+
token_data["audience"] = self.context.get_resource_url()
119+
120+
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
121+
"""Build token exchange request for JWT bearer grant."""
122+
if not self.context.client_info:
123+
raise OAuthFlowError("Missing client info")
124+
if not self.jwt_parameters:
125+
raise OAuthFlowError("Missing JWT parameters")
126+
if not self.context.oauth_metadata:
127+
raise OAuthTokenError("Missing OAuth metadata")
128+
129+
# We need to set the audience to the issuer identifier of the authorization server
130+
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
131+
issuer = str(self.context.oauth_metadata.issuer)
132+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
133+
134+
token_data = {
135+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
136+
"assertion": assertion,
137+
}
138+
139+
if self.context.should_include_resource_param(self.context.protocol_version):
140+
token_data["resource"] = self.context.get_resource_url()
141+
142+
if self.context.client_metadata.scope:
143+
token_data["scope"] = self.context.client_metadata.scope
144+
145+
token_url = self._get_token_endpoint()
146+
return httpx.Request(
147+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
148+
)

0 commit comments

Comments
 (0)