Skip to content

Commit 9327cff

Browse files
committed
Improved Trio support
* Removed the asyncio-only parametrization of the anyio_backend except for test_ws, as `websockets` doesn't support Trio yet * Try to close async generators explicitly where possible * Changed nesting order for more predictable closing of async resources * Refactored `__aenter__` and `__aexit__` in some cases to exit the task group if there's a problem during initialization * Fixed test failures in client/test_auth.py where an async fixture was used in sync tests * Fixed subtle bug in `SimpleEventStore` where retrieving the stream ID was timing-dependent
1 parent 0bcecff commit 9327cff

File tree

10 files changed

+83
-58
lines changed

10 files changed

+83
-58
lines changed

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"python-multipart>=0.0.9",
3131
"sse-starlette>=1.6.1",
3232
"pydantic-settings>=2.5.2",
33+
"typing_extensions>=4.12",
3334
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
3435
]
3536

@@ -48,10 +49,10 @@ required-version = ">=0.7.2"
4849

4950
[dependency-groups]
5051
dev = [
52+
"anyio[trio]",
5153
"pyright>=1.1.391",
5254
"pytest>=8.3.4",
5355
"ruff>=0.8.5",
54-
"trio>=0.26.2",
5556
"pytest-flakefinder>=1.1.0",
5657
"pytest-xdist>=3.6.1",
5758
"pytest-examples>=0.0.14",
@@ -122,5 +123,8 @@ filterwarnings = [
122123
# This should be fixed on Uvicorn's side.
123124
"ignore::DeprecationWarning:websockets",
124125
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
125-
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
126+
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
127+
# This is to avoid test failures on Trio due to httpx's failure to explicitly close
128+
# async generators
129+
"ignore::pytest.PytestUnraisableExceptionWarning"
126130
]

src/mcp/client/streamable_http.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
import logging
1010
from collections.abc import AsyncGenerator, Awaitable, Callable
11-
from contextlib import asynccontextmanager
11+
from contextlib import aclosing, asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
14+
from typing import cast
1415

1516
import anyio
1617
import httpx
@@ -284,16 +285,18 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
284285
"""Handle SSE response from the server."""
285286
try:
286287
event_source = EventSource(response)
287-
async for sse in event_source.aiter_sse():
288-
is_complete = await self._handle_sse_event(
289-
sse,
290-
ctx.read_stream_writer,
291-
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
292-
)
293-
# If the SSE event indicates completion, like returning respose/error
294-
# break the loop
295-
if is_complete:
296-
break
288+
sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse())
289+
async with aclosing(sse_iter) as items:
290+
async for sse in items:
291+
is_complete = await self._handle_sse_event(
292+
sse,
293+
ctx.read_stream_writer,
294+
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
295+
)
296+
# If the SSE event indicates completion, like returning respose/error
297+
# break the loop
298+
if is_complete:
299+
break
297300
except Exception as e:
298301
logger.exception("Error reading SSE stream:")
299302
await ctx.read_stream_writer.send(e)
@@ -434,15 +437,16 @@ async def streamablehttp_client(
434437
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
435438
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
436439

437-
async with anyio.create_task_group() as tg:
438-
try:
439-
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
440+
try:
441+
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
440442

441-
async with httpx_client_factory(
442-
headers=transport.request_headers,
443-
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
444-
auth=transport.auth,
445-
) as client:
443+
async with create_mcp_http_client(
444+
headers=transport.request_headers,
445+
timeout=httpx.Timeout(
446+
transport.timeout, read=transport.sse_read_timeout
447+
),
448+
) as client:
449+
async with anyio.create_task_group() as tg:
446450
# Define callbacks that need access to tg
447451
def start_get_stream() -> None:
448452
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
@@ -467,6 +471,6 @@ def start_get_stream() -> None:
467471
if transport.session_id and terminate_on_close:
468472
await transport.terminate_session(client)
469473
tg.cancel_scope.cancel()
470-
finally:
471-
await read_stream_writer.aclose()
472-
await write_stream.aclose()
474+
finally:
475+
await read_stream_writer.aclose()
476+
await write_stream.aclose()

src/mcp/server/session.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4444
import anyio.lowlevel
4545
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4646
from pydantic import AnyUrl
47+
from typing_extensions import Self
4748

4849
import mcp.types as types
4950
from mcp.server.models import InitializationOptions
@@ -93,10 +94,16 @@ def __init__(
9394
)
9495

9596
self._init_options = init_options
96-
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
97-
ServerRequestResponder
98-
](0)
99-
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
97+
98+
async def __aenter__(self) -> Self:
99+
await super().__aenter__()
100+
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
101+
anyio.create_memory_object_stream[ServerRequestResponder](0)
102+
)
103+
self._exit_stack.push_async_callback(
104+
self._incoming_message_stream_reader.aclose
105+
)
106+
return self
100107

101108
@property
102109
def client_params(self) -> types.InitializeRequestParams | None:

src/mcp/shared/session.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import anyio
99
import httpx
10+
from anyio.abc import TaskGroup
1011
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1112
from pydantic import BaseModel
1213
from typing_extensions import Self
@@ -177,6 +178,8 @@ class BaseSession(
177178
_request_id: int
178179
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
179180
_progress_callbacks: dict[RequestId, ProgressFnT]
181+
_exit_stack: AsyncExitStack
182+
_task_group: TaskGroup
180183

181184
def __init__(
182185
self,
@@ -196,12 +199,19 @@ def __init__(
196199
self._session_read_timeout_seconds = read_timeout_seconds
197200
self._in_flight = {}
198201
self._progress_callbacks = {}
199-
self._exit_stack = AsyncExitStack()
200202

201203
async def __aenter__(self) -> Self:
202-
self._task_group = anyio.create_task_group()
203-
await self._task_group.__aenter__()
204-
self._task_group.start_soon(self._receive_loop)
204+
async with AsyncExitStack() as exit_stack:
205+
self._task_group = await exit_stack.enter_async_context(
206+
anyio.create_task_group()
207+
)
208+
self._task_group.start_soon(self._receive_loop)
209+
# Using BaseSession as a context manager should not block on exit (this
210+
# would be very surprising behavior), so make sure to cancel the tasks
211+
# in the task group.
212+
exit_stack.callback(self._task_group.cancel_scope.cancel)
213+
self._exit_stack = exit_stack.pop_all()
214+
205215
return self
206216

207217
async def __aexit__(
@@ -210,12 +220,7 @@ async def __aexit__(
210220
exc_val: BaseException | None,
211221
exc_tb: TracebackType | None,
212222
) -> bool | None:
213-
await self._exit_stack.aclose()
214-
# Using BaseSession as a context manager should not block on exit (this
215-
# would be very surprising behavior), so make sure to cancel the tasks
216-
# in the task group.
217-
self._task_group.cancel_scope.cancel()
218-
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
223+
return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
219224

220225
async def send_request(
221226
self,

tests/client/test_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def oauth_token():
9999

100100

101101
@pytest.fixture
102-
async def oauth_provider(client_metadata, mock_storage):
102+
def oauth_provider(client_metadata, mock_storage):
103103
async def mock_redirect_handler(url: str) -> None:
104104
pass
105105

tests/client/test_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,15 @@ async def mock_server():
334334
)
335335

336336
async with (
337+
client_to_server_send,
338+
client_to_server_receive,
339+
server_to_client_send,
340+
server_to_client_receive,
337341
ClientSession(
338342
server_to_client_receive,
339343
client_to_server_send,
340344
) as session,
341345
anyio.create_task_group() as tg,
342-
client_to_server_send,
343-
client_to_server_receive,
344-
server_to_client_send,
345-
server_to_client_receive,
346346
):
347347
tg.start_soon(mock_server)
348348

tests/conftest.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

tests/shared/test_streamable_http.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,17 @@ async def replay_events_after(
8787
"""Replay events after the specified ID."""
8888
# Find the index of the last event ID
8989
start_index = None
90-
for i, (_, event_id, _) in enumerate(self._events):
90+
stream_id = None
91+
for i, (stream_id_, event_id, _) in enumerate(self._events):
9192
if event_id == last_event_id:
9293
start_index = i + 1
94+
stream_id = stream_id_
9395
break
9496

9597
if start_index is None:
9698
# If event ID not found, start from beginning
9799
start_index = 0
98100

99-
stream_id = None
100101
# Replay events
101102
for _, event_id, message in self._events[start_index:]:
102103
await send_callback(EventMessage(message, event_id))
@@ -1003,7 +1004,8 @@ async def test_streamablehttp_client_resumption(event_server):
10031004
captured_session_id = None
10041005
captured_resumption_token = None
10051006
captured_notifications = []
1006-
tool_started = False
1007+
tool_started_event = anyio.Event()
1008+
session_resumption_token_received_event = anyio.Event()
10071009

10081010
async def message_handler(
10091011
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
@@ -1013,12 +1015,12 @@ async def message_handler(
10131015
# Look for our special notification that indicates the tool is running
10141016
if isinstance(message.root, types.LoggingMessageNotification):
10151017
if message.root.params.data == "Tool started":
1016-
nonlocal tool_started
1017-
tool_started = True
1018+
tool_started_event.set()
10181019

10191020
async def on_resumption_token_update(token: str) -> None:
10201021
nonlocal captured_resumption_token
10211022
captured_resumption_token = token
1023+
session_resumption_token_received_event.set()
10221024

10231025
# First, start the client session and begin the long-running tool
10241026
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
@@ -1055,8 +1057,8 @@ async def run_tool():
10551057

10561058
# Wait for the tool to start and at least one notification
10571059
# and then kill the task group
1058-
while not tool_started or not captured_resumption_token:
1059-
await anyio.sleep(0.1)
1060+
await tool_started_event.wait()
1061+
await session_resumption_token_received_event.wait()
10601062
tg.cancel_scope.cancel()
10611063

10621064
# Store pre notifications and clear the captured notifications

tests/shared/test_ws.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
SERVER_NAME = "test_server_for_WS"
2929

30+
pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"])
31+
3032

3133
@pytest.fixture
3234
def server_port() -> int:

uv.lock

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)