Skip to content

Commit ce597dd

Browse files
committed
s/while/for/, delete instance variables
1 parent eeb74bd commit ce597dd

File tree

2 files changed

+18
-60
lines changed

2 files changed

+18
-60
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,6 @@ class OAuthContext:
108108
# State
109109
lock: anyio.Lock = field(default_factory=anyio.Lock)
110110

111-
# Discovery state for fallback support (SEP-985)
112-
discovery_urls: list[str] = field(default_factory=lambda: [])
113-
discovery_index: int = 0
114-
115111
def get_authorization_base_url(self, server_url: str) -> str:
116112
"""Extract base URL by removing path component."""
117113
parsed = urlparse(server_url)
@@ -141,11 +137,6 @@ def clear_tokens(self) -> None:
141137
self.current_tokens = None
142138
self.token_expiry_time = None
143139

144-
def reset_discovery_state(self) -> None:
145-
"""Reset protected resource metadata discovery state."""
146-
self.discovery_urls = []
147-
self.discovery_index = 0
148-
149140
def get_resource_url(self) -> str:
150141
"""Get resource URL for RFC 8707.
151142
@@ -288,32 +279,6 @@ def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | N
288279
"""
289280
return self._extract_field_from_www_auth(init_response, "scope")
290281

291-
async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
292-
"""
293-
Build protected resource metadata discovery request.
294-
295-
Per SEP-985, supports multiple discovery mechanisms with fallback:
296-
1. WWW-Authenticate header (if present)
297-
2. Path-based well-known URI
298-
3. Root-based well-known URI
299-
300-
Returns:
301-
Request for the next discovery URL to try
302-
"""
303-
# Initialize discovery URLs on first call
304-
if not self.context.discovery_urls:
305-
self.context.discovery_urls = self._build_protected_resource_discovery_urls(init_response)
306-
self.context.discovery_index = 0
307-
308-
# Get current URL to try
309-
if self.context.discovery_index < len(self.context.discovery_urls):
310-
url = self.context.discovery_urls[self.context.discovery_index]
311-
else:
312-
# No more URLs to try - this shouldn't happen in normal flow
313-
raise OAuthFlowError("Protected resource metadata discovery failed: all URLs exhausted")
314-
315-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
316-
317282
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
318283
"""
319284
Handle protected resource metadata discovery response.
@@ -644,25 +609,21 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
644609
if response.status_code == 401:
645610
# Perform full OAuth flow
646611
try:
647-
# Reset discovery state for new OAuth flow
648-
self.context.reset_discovery_state()
649-
650612
# OAuth flow must be inline due to generator constraints
651613
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
652-
# Try discovery URLs in order until one succeeds
614+
discovery_urls = self._build_protected_resource_discovery_urls(response)
653615
discovery_success = False
654-
while not discovery_success:
655-
discovery_request = await self._discover_protected_resource(response)
616+
for url in discovery_urls:
617+
discovery_request = httpx.Request(
618+
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
619+
)
656620
discovery_response = yield discovery_request
657621
discovery_success = await self._handle_protected_resource_response(discovery_response)
622+
if discovery_success:
623+
break
658624

659-
if not discovery_success:
660-
# Try next URL in fallback chain
661-
self.context.discovery_index += 1
662-
if self.context.discovery_index >= len(self.context.discovery_urls):
663-
raise OAuthFlowError(
664-
"Protected resource metadata discovery failed: no valid metadata found"
665-
)
625+
if not discovery_success:
626+
raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")
666627

667628
# Step 2: Apply scope selection strategy
668629
self._select_scopes(response)

tests/client/test_auth.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,10 @@ class TestOAuthFlow:
241241
"""Test OAuth flow methods."""
242242

243243
@pytest.mark.anyio
244-
async def test_discover_protected_resource_request(
244+
async def test_build_protected_resource_discovery_urls(
245245
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
246246
):
247-
"""Test protected resource discovery request building maintains backward compatibility."""
247+
"""Test protected resource metadata discovery URL building with fallback."""
248248

249249
async def redirect_handler(url: str) -> None:
250250
pass
@@ -265,22 +265,19 @@ async def callback_handler() -> tuple[str, str | None]:
265265
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
266266
)
267267

268-
request = await provider._discover_protected_resource(init_response)
269-
assert request.method == "GET"
270-
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
271-
assert "mcp-protocol-version" in request.headers
268+
urls = provider._build_protected_resource_discovery_urls(init_response)
269+
assert len(urls) == 1
270+
assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource"
272271

273272
# Test with WWW-Authenticate header
274-
# Reset discovery state for new test case
275-
provider.context.reset_discovery_state()
276273
init_response.headers["WWW-Authenticate"] = (
277274
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
278275
)
279276

280-
request = await provider._discover_protected_resource(init_response)
281-
assert request.method == "GET"
282-
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
283-
assert "mcp-protocol-version" in request.headers
277+
urls = provider._build_protected_resource_discovery_urls(init_response)
278+
assert len(urls) == 2
279+
assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path"
280+
assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource"
284281

285282
@pytest.mark.anyio
286283
def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider):

0 commit comments

Comments
 (0)