1313import time
1414from collections .abc import AsyncGenerator , Awaitable , Callable
1515from dataclasses import dataclass , field
16- from typing import Protocol
16+ from typing import Any , Protocol
1717from urllib .parse import urlencode , urljoin , urlparse
1818
1919import anyio
@@ -88,8 +88,8 @@ class OAuthContext:
8888 server_url : str
8989 client_metadata : OAuthClientMetadata
9090 storage : TokenStorage
91- redirect_handler : Callable [[str ], Awaitable [None ]]
92- callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]]
91+ redirect_handler : Callable [[str ], Awaitable [None ]] | None
92+ callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None
9393 timeout : float = 300.0
9494
9595 # Discovered metadata
@@ -194,8 +194,8 @@ def __init__(
194194 server_url : str ,
195195 client_metadata : OAuthClientMetadata ,
196196 storage : TokenStorage ,
197- redirect_handler : Callable [[str ], Awaitable [None ]],
198- callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]],
197+ redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
198+ callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
199199 timeout : float = 300.0 ,
200200 ):
201201 """Initialize OAuth2 authentication."""
@@ -423,8 +423,21 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
423423 except ValidationError as e :
424424 raise OAuthRegistrationError (f"Invalid registration response: { e } " )
425425
426- async def _perform_authorization (self ) -> tuple [str , str ]:
426+ async def _perform_authorization (self ) -> httpx .Request :
427+ """Perform the authorization flow."""
428+ auth_code , code_verifier = await self ._perform_authorization_code_grant ()
429+ token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
430+ return token_request
431+
432+ async def _perform_authorization_code_grant (self ) -> tuple [str , str ]:
427433 """Perform the authorization redirect and get auth code."""
434+ if self .context .client_metadata .redirect_uris is None :
435+ raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
436+ if not self .context .redirect_handler :
437+ raise OAuthFlowError ("No redirect handler provided for authorization code grant" )
438+ if not self .context .callback_handler :
439+ raise OAuthFlowError ("No callback handler provided for authorization code grant" )
440+
428441 if self .context .oauth_metadata and self .context .oauth_metadata .authorization_endpoint :
429442 auth_endpoint = str (self .context .oauth_metadata .authorization_endpoint )
430443 else :
@@ -469,24 +482,34 @@ async def _perform_authorization(self) -> tuple[str, str]:
469482 # Return auth code and code verifier for token exchange
470483 return auth_code , pkce_params .code_verifier
471484
472- async def _exchange_token (self , auth_code : str , code_verifier : str ) -> httpx .Request :
473- """Build token exchange request."""
474- if not self .context .client_info :
475- raise OAuthFlowError ("Missing client info" )
476-
485+ def _get_token_endpoint (self ) -> str :
477486 if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
478487 token_url = str (self .context .oauth_metadata .token_endpoint )
479488 else :
480489 auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
481490 token_url = urljoin (auth_base_url , "/token" )
491+ return token_url
492+
493+ async def _exchange_token_authorization_code (
494+ self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] | None = {}
495+ ) -> httpx .Request :
496+ """Build token exchange request for authorization_code flow."""
497+ if self .context .client_metadata .redirect_uris is None :
498+ raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
499+ if not self .context .client_info :
500+ raise OAuthFlowError ("Missing client info" )
482501
483- token_data = {
484- "grant_type" : "authorization_code" ,
485- "code" : auth_code ,
486- "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
487- "client_id" : self .context .client_info .client_id ,
488- "code_verifier" : code_verifier ,
489- }
502+ token_url = self ._get_token_endpoint ()
503+ token_data = token_data or {}
504+ token_data .update (
505+ {
506+ "grant_type" : "authorization_code" ,
507+ "code" : auth_code ,
508+ "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
509+ "client_id" : self .context .client_info .client_id ,
510+ "code_verifier" : code_verifier ,
511+ }
512+ )
490513
491514 # Only include resource param if conditions are met
492515 if self .context .should_include_resource_param (self .context .protocol_version ):
@@ -502,7 +525,9 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
502525 async def _handle_token_response (self , response : httpx .Response ) -> None :
503526 """Handle token exchange response."""
504527 if response .status_code != 200 :
505- raise OAuthTokenError (f"Token exchange failed: { response .status_code } " )
528+ body = await response .aread ()
529+ body = body .decode ("utf-8" )
530+ raise OAuthTokenError (f"Token exchange failed ({ response .status_code } ): { body } " )
506531
507532 try :
508533 content = await response .aread ()
@@ -663,12 +688,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
663688 registration_response = yield registration_request
664689 await self ._handle_registration_response (registration_response )
665690
666- # Step 5: Perform authorization
667- auth_code , code_verifier = await self ._perform_authorization ()
668-
669- # Step 6: Exchange authorization code for tokens
670- token_request = await self ._exchange_token (auth_code , code_verifier )
671- token_response = yield token_request
691+ # Step 5: Perform authorization and complete token exchange
692+ token_response = yield await self ._perform_authorization ()
672693 await self ._handle_token_response (token_response )
673694 except Exception :
674695 logger .exception ("OAuth flow error" )
@@ -687,17 +708,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
687708 # Step 2a: Update the required scopes
688709 self ._select_scopes (response )
689710
690- # Step 2b: Perform (re-)authorization
691- auth_code , code_verifier = await self ._perform_authorization ()
692-
693- # Step 2c: Exchange authorization code for tokens
694- token_request = await self ._exchange_token (auth_code , code_verifier )
695- token_response = yield token_request
711+ # Step 2b: Perform (re-)authorization and token exchange
712+ token_response = yield await self ._perform_authorization ()
696713 await self ._handle_token_response (token_response )
697714 except Exception :
698715 logger .exception ("OAuth flow error" )
699716 raise
700717
701- # Retry with new tokens
702- self ._add_auth_header (request )
703- yield request
718+ # Retry with new tokens
719+ self ._add_auth_header (request )
720+ yield request
0 commit comments