Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,17 @@ async def _receive_loop(self) -> None:
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
try:
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
except Exception as e:
logging.error(
"Progress callback raised an exception: %s",
e,
)
Comment on lines +401 to +405
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an outer exception block 🤔

Copy link
Contributor

@felixweinberger felixweinberger Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair - how would you propose addressing the misleading part here instead? The callback could be anything, so it could throw any kind of exception.

We could potentially replace 410 to just be a generic logging.exception(e) without any additional decoration? In order to be more broad without adding misleading commentary?

Comment on lines +402 to +405
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logging.exception instead of logging.error.

await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
Expand Down
68 changes: 68 additions & 0 deletions tests/shared/test_progress_notifications.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, cast
from unittest.mock import patch

import anyio
import pytest
Expand All @@ -10,6 +11,7 @@
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.progress import progress
from mcp.shared.session import BaseSession, RequestResponder, SessionMessage

Expand Down Expand Up @@ -320,3 +322,69 @@ async def handle_client_message(
assert server_progress_updates[3]["progress"] == 100
assert server_progress_updates[3]["total"] == 100
assert server_progress_updates[3]["message"] == "Processing results..."


@pytest.mark.anyio
async def test_progress_callback_exception_logging():
"""Test that exceptions in progress callbacks are logged and \
don't crash the session."""
# Track logged warnings
logged_errors: list[str] = []

def mock_log_error(msg: str, *args: Any) -> None:
logged_errors.append(msg % args if args else msg)

# Create a progress callback that raises an exception
async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None:
raise ValueError("Progress callback failed!")

# Create a server with a tool that sends progress notifications
server = Server(name="TestProgressServer")

@server.call_tool()
async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]:
if name == "progress_tool":
# Send a progress notification
await server.request_context.session.send_progress_notification(
progress_token=server.request_context.request_id,
progress=50.0,
total=100.0,
message="Halfway done",
)
return [types.TextContent(type="text", text="progress_result")]
raise ValueError(f"Unknown tool: {name}")

@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="progress_tool",
description="A tool that sends progress notifications",
inputSchema={},
)
]

# Test with mocked logging
with patch("mcp.shared.session.logging.error", side_effect=mock_log_error):
async with create_connected_server_and_client_session(server) as client_session:
# Send a request with a failing progress callback
result = await client_session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name="progress_tool", arguments={}),
)
),
types.CallToolResult,
progress_callback=failing_progress_callback,
)

# Verify the request completed successfully despite the callback failure
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, types.TextContent)
assert content.text == "progress_result"

# Check that a warning was logged for the progress callback exception
assert len(logged_errors) > 0
assert any("Progress callback raised an exception" in warning for warning in logged_errors)
Loading