Skip to content

Commit 1574814

Browse files
committed
refactor: extract common context injection pattern
Eliminated code duplication across tools, resources, and prompts by creating a centralized context injection utility module. Changes: - Add context_injection.py utility with find_context_parameter() and inject_context() - Refactor tool, resource, and prompt implementations to use the shared utility - Remove duplicated context detection logic from all three implementations - Maintain backward compatibility with all existing tests passing This refactoring reduces code duplication and makes the context injection pattern consistent and maintainable across all FastMCP handler types.
1 parent 6fd562d commit 1574814

File tree

5 files changed

+129
-48
lines changed

5 files changed

+129
-48
lines changed

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
import inspect
66
from collections.abc import Awaitable, Callable, Sequence
7-
from typing import TYPE_CHECKING, Any, Literal, get_origin
7+
from typing import TYPE_CHECKING, Any, Literal
88

99
import pydantic_core
1010
from pydantic import BaseModel, Field, TypeAdapter, validate_call
1111

12+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
1213
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
1314
from mcp.types import ContentBlock, TextContent
1415

@@ -96,20 +97,7 @@ def from_function(
9697

9798
# Find context parameter if it exists
9899
if context_kwarg is None:
99-
from mcp.server.fastmcp.server import Context
100-
101-
sig = inspect.signature(fn)
102-
for param_name, param in sig.parameters.items():
103-
if get_origin(param.annotation) is not None:
104-
continue
105-
if param.annotation is not inspect.Parameter.empty:
106-
try:
107-
if issubclass(param.annotation, Context):
108-
context_kwarg = param_name
109-
break
110-
except TypeError:
111-
# issubclass raises TypeError for non-class types
112-
pass
100+
context_kwarg = find_context_parameter(fn)
113101

114102
# Get schema from func_metadata, excluding context parameter
115103
func_arg_metadata = func_metadata(
@@ -159,9 +147,7 @@ async def render(
159147

160148
try:
161149
# Add context to arguments if needed
162-
call_args = arguments or {}
163-
if self.context_kwarg is not None and context is not None:
164-
call_args = {**call_args, self.context_kwarg: context}
150+
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
165151

166152
# Call function and check if result is a coroutine
167153
result = self.fn(**call_args)

src/mcp/server/fastmcp/resources/templates.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import inspect
66
import re
77
from collections.abc import Callable
8-
from typing import TYPE_CHECKING, Any, get_origin
8+
from typing import TYPE_CHECKING, Any
99

1010
from pydantic import BaseModel, Field, validate_call
1111

1212
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
13+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
1314
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
1415

1516
if TYPE_CHECKING:
@@ -48,20 +49,7 @@ def from_function(
4849

4950
# Find context parameter if it exists
5051
if context_kwarg is None:
51-
from mcp.server.fastmcp.server import Context
52-
53-
sig = inspect.signature(fn)
54-
for param_name, param in sig.parameters.items():
55-
if get_origin(param.annotation) is not None:
56-
continue
57-
if param.annotation is not inspect.Parameter.empty:
58-
try:
59-
if issubclass(param.annotation, Context):
60-
context_kwarg = param_name
61-
break
62-
except TypeError:
63-
# issubclass raises TypeError for non-class types
64-
pass
52+
context_kwarg = find_context_parameter(fn)
6553

6654
# Get schema from func_metadata, excluding context parameter
6755
func_arg_metadata = func_metadata(
@@ -102,8 +90,7 @@ async def create_resource(
10290
"""Create a resource from the template with the given parameters."""
10391
try:
10492
# Add context to params if needed
105-
if self.context_kwarg is not None and context is not None:
106-
params = {**params, self.context_kwarg: context}
93+
params = inject_context(self.fn, params, context, self.context_kwarg)
10794

10895
# Call function and check if result is a coroutine
10996
result = self.fn(**params)

src/mcp/server/fastmcp/tools/base.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import inspect
55
from collections.abc import Callable
66
from functools import cached_property
7-
from typing import TYPE_CHECKING, Any, get_origin
7+
from typing import TYPE_CHECKING, Any
88

99
from pydantic import BaseModel, Field
1010

1111
from mcp.server.fastmcp.exceptions import ToolError
12+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
1213
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
1314
from mcp.types import ToolAnnotations
1415

@@ -49,8 +50,6 @@ def from_function(
4950
structured_output: bool | None = None,
5051
) -> Tool:
5152
"""Create a Tool from a function."""
52-
from mcp.server.fastmcp.server import Context
53-
5453
func_name = name or fn.__name__
5554

5655
if func_name == "<lambda>":
@@ -60,13 +59,7 @@ def from_function(
6059
is_async = _is_async_callable(fn)
6160

6261
if context_kwarg is None:
63-
sig = inspect.signature(fn)
64-
for param_name, param in sig.parameters.items():
65-
if get_origin(param.annotation) is not None:
66-
continue
67-
if issubclass(param.annotation, Context):
68-
context_kwarg = param_name
69-
break
62+
context_kwarg = find_context_parameter(fn)
7063

7164
func_arg_metadata = func_metadata(
7265
fn,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Context injection utilities for FastMCP."""
2+
3+
from __future__ import annotations
4+
5+
import inspect
6+
from collections.abc import Callable
7+
from typing import Any, get_origin
8+
9+
10+
def find_context_parameter(fn: Callable[..., Any]) -> str | None:
11+
"""Find the parameter that should receive the Context object.
12+
13+
Searches through the function's signature to find a parameter
14+
with a Context type annotation.
15+
16+
Args:
17+
fn: The function to inspect
18+
19+
Returns:
20+
The name of the context parameter, or None if not found
21+
"""
22+
from mcp.server.fastmcp.server import Context
23+
24+
sig = inspect.signature(fn)
25+
for param_name, param in sig.parameters.items():
26+
# Skip generic types
27+
if get_origin(param.annotation) is not None:
28+
continue
29+
30+
# Check if parameter has annotation
31+
if param.annotation is not inspect.Parameter.empty:
32+
try:
33+
# Check if it's a Context subclass
34+
if issubclass(param.annotation, Context):
35+
return param_name
36+
except TypeError:
37+
# issubclass raises TypeError for non-class types
38+
pass
39+
40+
return None
41+
42+
43+
def inject_context(
44+
fn: Callable[..., Any],
45+
kwargs: dict[str, Any],
46+
context: Any | None,
47+
context_kwarg: str | None,
48+
) -> dict[str, Any]:
49+
"""Inject context into function kwargs if needed.
50+
51+
Args:
52+
fn: The function that will be called
53+
kwargs: The current keyword arguments
54+
context: The context object to inject (if any)
55+
context_kwarg: The name of the parameter to inject into
56+
57+
Returns:
58+
Updated kwargs with context injected if applicable
59+
"""
60+
if context_kwarg is not None and context is not None:
61+
return {**kwargs, context_kwarg: context}
62+
return kwargs

0 commit comments

Comments
 (0)