diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 49d289fb7..78ac252c0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -550,16 +550,23 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], + func: Callable[ + [str | int, float, float | None, str | None, ServerSession | None], + Awaitable[None], + ], ): logger.debug("Registering handler for ProgressNotification") - async def handler(req: types.ProgressNotification): + async def handler( + req: types.ProgressNotification, + session: ServerSession | None = None, + ): await func( req.params.progressToken, req.params.progress, req.params.total, req.params.message, + session, ) self.notification_handlers[types.ProgressNotification] = handler @@ -567,6 +574,36 @@ async def handler(req: types.ProgressNotification): return decorator + def initialized_notification(self): + """Decorator to register a handler for InitializedNotification.""" + + def decorator( + func: Callable[ + [types.InitializedNotification, ServerSession | None], + Awaitable[None], + ], + ): + logger.debug("Registering handler for InitializedNotification") + self.notification_handlers[types.InitializedNotification] = func + return func + + return decorator + + def roots_list_changed_notification(self): + """Decorator to register a handler for RootsListChangedNotification.""" + + def decorator( + func: Callable[ + [types.RootsListChangedNotification, ServerSession | None], + Awaitable[None], + ], + ): + logger.debug("Registering handler for RootsListChangedNotification") + self.notification_handlers[types.RootsListChangedNotification] = func + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" @@ -638,7 +675,7 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception), session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, @@ -649,7 +686,7 @@ async def _handle_message( with responder: await self._handle_request(message, req, session, lifespan_context, raise_exceptions) case types.ClientNotification(root=notify): - await self._handle_notification(notify) + await self._handle_notification(notify, session) case Exception(): # pragma: no cover logger.error(f"Received exception from stream: {message}") await session.send_log_message( @@ -660,8 +697,8 @@ async def _handle_message( if raise_exceptions: raise message - for warning in w: # pragma: no cover - logger.info("Warning: %s: %s", warning.category.__name__, warning.message) + for warning in w: # pragma: no cover + logger.info("Warning: %s: %s", warning.category.__name__, warning.message) async def _handle_request( self, @@ -724,12 +761,15 @@ async def _handle_request( logger.debug("Response sent") - async def _handle_notification(self, notify: Any): + async def _handle_notification(self, notify: Any, session: ServerSession): if handler := self.notification_handlers.get(type(notify)): # type: ignore logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + try: + await handler(notify, session) + except TypeError: + await handler(notify) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_notifications.py similarity index 78% rename from tests/shared/test_progress_notifications.py rename to tests/shared/test_notifications.py index 1552711d2..8e617a6be 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_notifications.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, cast from unittest.mock import patch @@ -12,8 +13,9 @@ from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.memory import create_connected_server_and_client_session +from mcp.shared.message import SessionMessage from mcp.shared.progress import progress -from mcp.shared.session import BaseSession, RequestResponder, SessionMessage +from mcp.shared.session import BaseSession, RequestResponder @pytest.mark.anyio @@ -23,6 +25,8 @@ async def test_bidirectional_progress_notifications(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + server_session_ref: list[ServerSession | None] = [None] + # Run a server session so we can send progress updates in tool async def run_server(): # Create a server session @@ -35,9 +39,7 @@ async def run_server(): capabilities=server.get_capabilities(NotificationOptions(), {}), ), ) as server_session: - global serv_sesh - - serv_sesh = server_session + server_session_ref[0] = server_session async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) @@ -62,6 +64,7 @@ async def handle_progress( progress: float, total: float | None, message: str | None, + session: ServerSession | None, ): server_progress_updates.append( { @@ -86,6 +89,10 @@ async def handle_list_tools() -> list[types.Tool]: # Register tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + serv_sesh = server_session_ref[0] + if not serv_sesh: + raise ValueError("Server session not available") + # Make sure we received a progress token if name == "test_tool": if arguments and "_meta" in arguments: @@ -228,6 +235,7 @@ async def handle_progress( progress: float, total: float | None, message: str | None, + session: ServerSession | None, ): server_progress_updates.append( {"token": progress_token, "progress": progress, "total": total, "message": message} @@ -390,3 +398,106 @@ async def handle_list_tools() -> list[types.Tool]: # Check that a warning was logged for the progress callback exception assert len(logged_errors) > 0 assert any("Progress callback raised an exception" in warning for warning in logged_errors) + + +@pytest.mark.anyio +async def test_initialized_notification(): + """Test that the server receives and handles InitializedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + received_session: ServerSession | None = None + + @server.initialized_notification() + async def handle_initialized( + notification: types.InitializedNotification, + session: ServerSession | None = None, + ): + nonlocal received_session + received_session = session + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + assert isinstance(received_session, ServerSession) + + +@pytest.mark.anyio +async def test_roots_list_changed_notification(): + """Test that the server receives and handles RootsListChangedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + received_session: ServerSession | None = None + + @server.roots_list_changed_notification() + async def handle_roots_list_changed( + notification: types.RootsListChangedNotification, + session: ServerSession | None = None, + ): + nonlocal received_session + received_session = session + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + assert isinstance(received_session, ServerSession)