Skip to content

Commit 2a592d2

Browse files
committed
feature: add type checking for passing request object
1 parent a42e097 commit 2a592d2

File tree

6 files changed

+695
-23
lines changed

6 files changed

+695
-23
lines changed

src/mcp/server/lowlevel/func_inspection.py

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import inspect
2+
import types
3+
import warnings
24
from collections.abc import Callable
3-
from typing import Any
5+
from typing import Any, TypeVar, Union, get_args, get_origin
46

57

68
def accepts_single_positional_arg(func: Callable[..., Any]) -> bool:
@@ -47,3 +49,169 @@ def accepts_single_positional_arg(func: Callable[..., Any]) -> bool:
4749
# not the responsibility of this function to check the validity of a
4850
# callback.
4951
return True
52+
53+
54+
def get_first_parameter_type(func: Callable[..., Any]) -> Any:
55+
"""
56+
Get the type annotation of the first parameter of a function.
57+
58+
Returns None if:
59+
- The function has no parameters
60+
- The first parameter has no type annotation
61+
- The signature cannot be inspected
62+
63+
Returns the actual annotation otherwise (could be a type, Any, Union, TypeVar, etc.)
64+
"""
65+
try:
66+
sig = inspect.signature(func)
67+
except (ValueError, TypeError):
68+
return None
69+
70+
params = list(sig.parameters.values())
71+
if not params:
72+
return None
73+
74+
first_param = params[0]
75+
76+
# Skip *args and **kwargs
77+
if first_param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
78+
return None
79+
80+
annotation = first_param.annotation
81+
if annotation == inspect.Parameter.empty:
82+
return None
83+
84+
return annotation
85+
86+
87+
def type_accepts_request(param_type: Any, request_type: type) -> bool:
88+
"""
89+
Check if a parameter type annotation can accept the request type.
90+
91+
Handles:
92+
- Exact type match
93+
- Union types (checks if request_type is in the Union)
94+
- TypeVars (checks if request_type matches the bound or constraints)
95+
- Generic types (basic support)
96+
- Any (always returns True)
97+
98+
Returns False for None or incompatible types.
99+
"""
100+
if param_type is None:
101+
return False
102+
103+
# Check for Any type
104+
if param_type is Any:
105+
return True
106+
107+
# Exact match
108+
if param_type == request_type:
109+
return True
110+
111+
# Handle Union types (both typing.Union and | syntax)
112+
origin = get_origin(param_type)
113+
if origin is Union or origin is types.UnionType:
114+
args = get_args(param_type)
115+
# Check if request_type is in the Union
116+
for arg in args:
117+
if arg == request_type:
118+
return True
119+
# Recursively check each union member
120+
if type_accepts_request(arg, request_type):
121+
return True
122+
return False
123+
124+
# Handle TypeVar
125+
if isinstance(param_type, TypeVar):
126+
# Check if request_type matches the bound
127+
if param_type.__bound__ is not None:
128+
if request_type == param_type.__bound__:
129+
return True
130+
# Check if request_type is a subclass of the bound
131+
try:
132+
if issubclass(request_type, param_type.__bound__):
133+
return True
134+
except TypeError:
135+
pass
136+
137+
# Check constraints
138+
if param_type.__constraints__:
139+
for constraint in param_type.__constraints__:
140+
if request_type == constraint:
141+
return True
142+
try:
143+
if issubclass(request_type, constraint):
144+
return True
145+
except TypeError:
146+
pass
147+
148+
return False
149+
150+
# For other generic types, check if request_type matches the origin
151+
if origin is not None:
152+
# Get the base generic type (e.g., list from list[str])
153+
return request_type == origin
154+
155+
return False
156+
157+
158+
def should_pass_request(func: Callable[..., Any], request_type: type) -> tuple[bool, bool]:
159+
"""
160+
Determine if a request should be passed to the function based on parameter type inspection.
161+
162+
Returns a tuple of (should_pass_request, should_deprecate):
163+
- should_pass_request: True if the request should be passed to the function
164+
- should_deprecate: True if a deprecation warning should be issued
165+
166+
The decision logic:
167+
1. If the function has no parameters -> (False, True) - old style without params, deprecate
168+
2. If the function has parameters but can't accept positional args -> (False, False)
169+
3. If the first parameter type accepts the request type -> (True, False) - pass request, no deprecation
170+
4. If the first parameter is typed as Any -> (True, True) - pass request but deprecate (effectively untyped)
171+
5. If the first parameter is typed with something incompatible -> (False, True) - old style, deprecate
172+
6. If the first parameter is untyped but accepts positional args -> (True, True) - pass request, deprecate
173+
"""
174+
can_accept_arg = accepts_single_positional_arg(func)
175+
176+
if not can_accept_arg:
177+
# Check if it has no parameters at all (old style)
178+
try:
179+
sig = inspect.signature(func)
180+
if len(sig.parameters) == 0:
181+
# Old style handler with no parameters - don't pass request but deprecate
182+
return False, True
183+
except (ValueError, TypeError):
184+
pass
185+
# Can't accept positional arguments for other reasons
186+
return False, False
187+
188+
param_type = get_first_parameter_type(func)
189+
190+
if param_type is None:
191+
# Untyped parameter - this is the old style, pass request but deprecate
192+
return True, True
193+
194+
# Check if the parameter type can accept the request
195+
if type_accepts_request(param_type, request_type):
196+
# Check if it's Any - if so, we should deprecate
197+
if param_type is Any:
198+
return True, True
199+
# Properly typed to accept the request - pass request, no deprecation
200+
return True, False
201+
202+
# Parameter is typed with something incompatible - this is an old style handler expecting
203+
# a different signature, don't pass request, issue deprecation
204+
return False, True
205+
206+
207+
def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> None:
208+
"""
209+
Issue a deprecation warning for handlers that don't use the new request parameter style.
210+
"""
211+
func_name = getattr(func, "__name__", str(func))
212+
warnings.warn(
213+
f"Handler '{func_name}' should accept a '{request_type.__name__}' parameter. "
214+
"Support for handlers without this parameter will be removed in a future version.",
215+
DeprecationWarning,
216+
stacklevel=4,
217+
)

src/mcp/server/lowlevel/server.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def main():
8282
from typing_extensions import TypeVar
8383

8484
import mcp.types as types
85-
from mcp.server.lowlevel.func_inspection import accepts_single_positional_arg
85+
from mcp.server.lowlevel.func_inspection import issue_deprecation_warning, should_pass_request
8686
from mcp.server.lowlevel.helper_types import ReadResourceContents
8787
from mcp.server.models import InitializationOptions
8888
from mcp.server.session import ServerSession
@@ -235,7 +235,10 @@ def decorator(
235235
| Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]],
236236
):
237237
logger.debug("Registering handler for PromptListRequest")
238-
pass_request = accepts_single_positional_arg(func)
238+
pass_request, should_deprecate = should_pass_request(func, types.ListPromptsRequest)
239+
240+
if should_deprecate:
241+
issue_deprecation_warning(func, types.ListPromptsRequest)
239242

240243
if pass_request:
241244
request_func = cast(Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], func)
@@ -280,7 +283,10 @@ def decorator(
280283
| Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]],
281284
):
282285
logger.debug("Registering handler for ListResourcesRequest")
283-
pass_request = accepts_single_positional_arg(func)
286+
pass_request, should_deprecate = should_pass_request(func, types.ListResourcesRequest)
287+
288+
if should_deprecate:
289+
issue_deprecation_warning(func, types.ListResourcesRequest)
284290

285291
if pass_request:
286292
request_func = cast(Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], func)
@@ -420,7 +426,10 @@ def decorator(
420426
| Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]],
421427
):
422428
logger.debug("Registering handler for ListToolsRequest")
423-
pass_request = accepts_single_positional_arg(func)
429+
pass_request, should_deprecate = should_pass_request(func, types.ListToolsRequest)
430+
431+
if should_deprecate:
432+
issue_deprecation_warning(func, types.ListToolsRequest)
424433

425434
if pass_request:
426435
request_func = cast(Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], func)

0 commit comments

Comments
 (0)