diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index d86fa85e3..c603388a2 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -84,6 +84,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): json_response: bool stateless_http: bool """Define if the server should create a new transport per request.""" + single_tenant: bool + """Define if the server should use only one transport for all requests.""" # resource settings warn_on_duplicate_resources: bool @@ -139,6 +141,7 @@ def __init__( streamable_http_path: str = "/mcp", json_response: bool = False, stateless_http: bool = False, + single_tenant: bool = False, warn_on_duplicate_resources: bool = True, warn_on_duplicate_tools: bool = True, warn_on_duplicate_prompts: bool = True, @@ -158,6 +161,7 @@ def __init__( streamable_http_path=streamable_http_path, json_response=json_response, stateless_http=stateless_http, + single_tenant=single_tenant, warn_on_duplicate_resources=warn_on_duplicate_resources, warn_on_duplicate_tools=warn_on_duplicate_tools, warn_on_duplicate_prompts=warn_on_duplicate_prompts, @@ -868,6 +872,7 @@ def streamable_http_app(self) -> Starlette: event_store=self._event_store, json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting + single_tenant=self.settings.single_tenant, security_settings=self.settings.transport_security, ) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 53d542d21..1828500c3 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -51,6 +51,9 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. + single_tenant: If True, only one transport will be created to process the entire + every request, regardless of the MCP session id. This is useful for + hosting platforms where the MCP server is launched in a single-tenant box. """ def __init__( @@ -59,12 +62,18 @@ def __init__( event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, + single_tenant: bool = False, security_settings: TransportSecuritySettings | None = None, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless + self.single_tenant = single_tenant + if self.stateless and self.single_tenant: + # A single-tenant server must be stateful, but stateful server does not + # have to be single tenant. + raise ValueError("A single-tenant server must stateful.") self.security_settings = security_settings # Session tracking (only used if not stateless) @@ -209,6 +218,19 @@ async def _handle_stateful_request( request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + if self.single_tenant and self._server_instances: + # being single_tenant means that there is only one ASGI server for the entire application. + # that server is used to process all the mcp requests because the hosting platform + # is already distributing the request to the box where the single-tenant mcp server runs. + assert len(self._server_instances) == 1 + # hosting platforms might exposes a different mcp session ID hence + # we take the first key of server instances as the mcp session id + request_mcp_session_id = next(iter(self._server_instances.keys())) + headers = dict(scope["headers"]) + # Also need to reset the incoming request mcp session id to this existing session id + headers[MCP_SESSION_ID_HEADER.encode("latin-1")] = request_mcp_session_id.encode("latin-1") + scope["headers"] = list(headers.items()) + # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 5e34ba1b1..7da616100 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1283,21 +1283,40 @@ def prompt_fn(name: str) -> str: with pytest.raises(McpError, match="Missing required arguments"): await client.get_prompt("prompt_fn") + def test_single_tenant_default_false(self): + """Test that single_tenant defaults to False.""" + mcp = FastMCP("test") + assert mcp.settings.single_tenant is False + + def test_single_tenant_can_be_set_true(self): + """Test that single_tenant can be set to True.""" + mcp = FastMCP("test", single_tenant=True) + assert mcp.settings.single_tenant is True + + def test_single_tenant_passed_to_session_manager(self): + """Test that single_tenant is passed to StreamableHTTPSessionManager.""" + mcp = FastMCP("test", single_tenant=True) + + # Access the session manager to trigger its creation + # We need to call streamable_http_app() first to initialize the session manager + mcp.streamable_http_app() + session_manager = mcp.session_manager + assert session_manager.single_tenant is True + + def test_streamable_http_no_redirect(self) -> None: + """Test that streamable HTTP routes are correctly configured.""" + mcp = FastMCP() + app = mcp.streamable_http_app() + + # Find routes by type - streamable_http_app creates Route objects, not Mount objects + streamable_routes = [ + r + for r in app.routes + if isinstance(r, Route) and hasattr(r, "path") and r.path == mcp.settings.streamable_http_path + ] -def test_streamable_http_no_redirect() -> None: - """Test that streamable HTTP routes are correctly configured.""" - mcp = FastMCP() - app = mcp.streamable_http_app() - - # Find routes by type - streamable_http_app creates Route objects, not Mount objects - streamable_routes = [ - r - for r in app.routes - if isinstance(r, Route) and hasattr(r, "path") and r.path == mcp.settings.streamable_http_path - ] - - # Verify routes exist - assert len(streamable_routes) == 1, "Should have one streamable route" + # Verify routes exist + assert len(streamable_routes) == 1, "Should have one streamable route" - # Verify path values - assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" + # Verify path values + assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 7a8551e5c..84b43fa31 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -5,7 +5,8 @@ import anyio import pytest -from starlette.types import Message +from starlette.requests import Request +from starlette.types import Message, Scope from mcp.server import streamable_http_manager from mcp.server.lowlevel import Server @@ -13,6 +14,64 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +def test_single_tenant_validation(): + """Test that single_tenant=True with stateless=True raises ValueError.""" + app = Server("test") + + with pytest.raises(ValueError, match="A single-tenant server must stateful"): + StreamableHTTPSessionManager(app=app, single_tenant=True, stateless=True) + + +def test_single_tenant_default_false(): + """Test that single_tenant defaults to False.""" + app = Server("test") + manager = StreamableHTTPSessionManager(app=app) + assert manager.single_tenant is False + + +def test_single_tenant_can_be_set_true(): + """Test that single_tenant can be set to True.""" + app = Server("test") + manager = StreamableHTTPSessionManager(app=app, single_tenant=True) + assert manager.single_tenant is True + + +@pytest.mark.anyio +async def test_single_tenant_reuses_existing_session(): + """Test that single_tenant mode reuses existing session.""" + app = Server("test") + manager = StreamableHTTPSessionManager(app=app, single_tenant=True) + + mock_mcp_run = AsyncMock(return_value=None) + # This will be called by StreamableHTTPSessionManager's run_server -> self.app.run + app.run = mock_mcp_run + + # Manually add a session to simulate existing session + existing_session_id = "existing-session-123" + mock_transport = AsyncMock() + manager._server_instances[existing_session_id] = mock_transport + + # Create a request with different session ID + request_mcp_session_id = "different-session-id" + scope: Scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (MCP_SESSION_ID_HEADER.encode("latin-1"), request_mcp_session_id.encode("latin-1")), + ], + } + + async with manager.run(): + await manager.handle_request(scope, AsyncMock(), AsyncMock()) + headers = Request(scope).headers + modified_session_id = headers[MCP_SESSION_ID_HEADER] + + assert modified_session_id == existing_session_id + assert len(manager._server_instances) == 1 + + @pytest.mark.anyio async def test_run_can_only_be_called_once(): """Test that run() can only be called once per instance."""