Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,23 +550,60 @@ async def handler(req: types.CallToolRequest):

def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
func: Callable[
[str | int, float, float | None, str | None, ServerSession | None],
Awaitable[None],
],
):
logger.debug("Registering handler for ProgressNotification")

async def handler(req: types.ProgressNotification):
async def handler(
req: types.ProgressNotification,
session: ServerSession | None = None,
):
await func(
req.params.progressToken,
req.params.progress,
req.params.total,
req.params.message,
session,
)

self.notification_handlers[types.ProgressNotification] = handler
return func

return decorator

def initialized_notification(self):
"""Decorator to register a handler for InitializedNotification."""

def decorator(
func: Callable[
[types.InitializedNotification, ServerSession | None],
Awaitable[None],
],
):
logger.debug("Registering handler for InitializedNotification")
self.notification_handlers[types.InitializedNotification] = func
return func

return decorator

def roots_list_changed_notification(self):
"""Decorator to register a handler for RootsListChangedNotification."""

def decorator(
func: Callable[
[types.RootsListChangedNotification, ServerSession | None],
Awaitable[None],
],
):
logger.debug("Registering handler for RootsListChangedNotification")
self.notification_handlers[types.RootsListChangedNotification] = func
return func

return decorator

def completion(self):
"""Provides completions for prompts and resource templates"""

Expand Down Expand Up @@ -638,7 +675,7 @@ async def run(

async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception),
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
Expand All @@ -649,7 +686,7 @@ async def _handle_message(
with responder:
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
await self._handle_notification(notify, session)
case Exception(): # pragma: no cover
logger.error(f"Received exception from stream: {message}")
await session.send_log_message(
Expand All @@ -660,8 +697,8 @@ async def _handle_message(
if raise_exceptions:
raise message

for warning in w: # pragma: no cover
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
for warning in w: # pragma: no cover
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)

async def _handle_request(
self,
Expand Down Expand Up @@ -724,12 +761,15 @@ async def _handle_request(

logger.debug("Response sent")

async def _handle_notification(self, notify: Any):
async def _handle_notification(self, notify: Any, session: ServerSession):
if handler := self.notification_handlers.get(type(notify)): # type: ignore
logger.debug("Dispatching notification of type %s", type(notify).__name__)

try:
await handler(notify)
try:
await handler(notify, session)
except TypeError:
await handler(notify)
except Exception: # pragma: no cover
logger.exception("Uncaught exception in notification handler")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, cast
from unittest.mock import patch

Expand All @@ -12,8 +13,9 @@
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.message import SessionMessage
from mcp.shared.progress import progress
from mcp.shared.session import BaseSession, RequestResponder, SessionMessage
from mcp.shared.session import BaseSession, RequestResponder


@pytest.mark.anyio
Expand All @@ -23,6 +25,8 @@ async def test_bidirectional_progress_notifications():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)

server_session_ref: list[ServerSession | None] = [None]

# Run a server session so we can send progress updates in tool
async def run_server():
# Create a server session
Expand All @@ -35,9 +39,7 @@ async def run_server():
capabilities=server.get_capabilities(NotificationOptions(), {}),
),
) as server_session:
global serv_sesh

serv_sesh = server_session
server_session_ref[0] = server_session
async for message in server_session.incoming_messages:
try:
await server._handle_message(message, server_session, {})
Expand All @@ -62,6 +64,7 @@ async def handle_progress(
progress: float,
total: float | None,
message: str | None,
session: ServerSession | None,
):
server_progress_updates.append(
{
Expand All @@ -86,6 +89,10 @@ async def handle_list_tools() -> list[types.Tool]:
# Register tool handler
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]:
serv_sesh = server_session_ref[0]
if not serv_sesh:
raise ValueError("Server session not available")

# Make sure we received a progress token
if name == "test_tool":
if arguments and "_meta" in arguments:
Expand Down Expand Up @@ -228,6 +235,7 @@ async def handle_progress(
progress: float,
total: float | None,
message: str | None,
session: ServerSession | None,
):
server_progress_updates.append(
{"token": progress_token, "progress": progress, "total": total, "message": message}
Expand Down Expand Up @@ -390,3 +398,106 @@ async def handle_list_tools() -> list[types.Tool]:
# Check that a warning was logged for the progress callback exception
assert len(logged_errors) > 0
assert any("Progress callback raised an exception" in warning for warning in logged_errors)


@pytest.mark.anyio
async def test_initialized_notification():
"""Test that the server receives and handles InitializedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
initialized_received = asyncio.Event()
received_session: ServerSession | None = None

@server.initialized_notification()
async def handle_initialized(
notification: types.InitializedNotification,
session: ServerSession | None = None,
):
nonlocal received_session
received_session = session
initialized_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await initialized_received.wait()
tg.cancel_scope.cancel()

assert initialized_received.is_set()
assert isinstance(received_session, ServerSession)


@pytest.mark.anyio
async def test_roots_list_changed_notification():
"""Test that the server receives and handles RootsListChangedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
roots_list_changed_received = asyncio.Event()
received_session: ServerSession | None = None

@server.roots_list_changed_notification()
async def handle_roots_list_changed(
notification: types.RootsListChangedNotification,
session: ServerSession | None = None,
):
nonlocal received_session
received_session = session
roots_list_changed_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await client_session.send_notification(
types.ClientNotification(
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
)
)
await roots_list_changed_received.wait()
tg.cancel_scope.cancel()

assert roots_list_changed_received.is_set()
assert isinstance(received_session, ServerSession)
Loading