3
3
from __future__ import annotations
4
4
5
5
import inspect
6
+ import typing
6
7
from collections .abc import Callable
7
- from typing import Any , get_origin
8
+ from typing import Any
8
9
9
10
10
11
def find_context_parameter (fn : Callable [..., Any ]) -> str | None :
@@ -21,21 +22,26 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None:
21
22
"""
22
23
from mcp .server .fastmcp .server import Context
23
24
24
- sig = inspect .signature (fn )
25
- for param_name , param in sig .parameters .items ():
26
- # Skip generic types
27
- if get_origin (param .annotation ) is not None :
28
- continue
29
-
30
- # Check if parameter has annotation
31
- if param .annotation is not inspect .Parameter .empty :
32
- try :
33
- # Check if it's a Context subclass
34
- if issubclass (param .annotation , Context ):
25
+ # Get type hints to properly resolve string annotations
26
+ try :
27
+ hints = typing .get_type_hints (fn )
28
+ except Exception :
29
+ # If we can't resolve type hints, we can't find the context parameter
30
+ return None
31
+
32
+ # Check each parameter's type hint
33
+ for param_name , annotation in hints .items ():
34
+ # Handle direct Context type
35
+ if inspect .isclass (annotation ) and issubclass (annotation , Context ):
36
+ return param_name
37
+
38
+ # Handle generic types like Optional[Context]
39
+ origin = typing .get_origin (annotation )
40
+ if origin is not None :
41
+ args = typing .get_args (annotation )
42
+ for arg in args :
43
+ if inspect .isclass (arg ) and issubclass (arg , Context ):
35
44
return param_name
36
- except TypeError :
37
- # issubclass raises TypeError for non-class types
38
- pass
39
45
40
46
return None
41
47
0 commit comments