diff --git a/mcp/streamable.go b/mcp/streamable.go index 4ab343b2..95894052 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -40,11 +40,7 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - onTransportDeletion func(sessionID string) // for testing only - - mu sync.Mutex - // TODO: we should store the ServerSession along with the transport, because - // we need to cancel keepalive requests when closing the transport. + mu sync.Mutex transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } @@ -67,6 +63,11 @@ type StreamableHTTPOptions struct { // // [ยง2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server JSONResponse bool + + // OnSessionClose is a callback function that is invoked when a [ServerSession] + // is closed. This happens when a session is ended explicitly by the MCP client + // or when it is interrupted due to a timeout or other errors. + OnSessionClose func(sessionID string) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -153,7 +154,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - transport.connection.Close() + // TODO: consider logging this error + _ = transport.session.Close() } w.WriteHeader(http.StatusNoContent) return @@ -286,8 +288,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() delete(h.transports, transport.SessionID) h.mu.Unlock() - if h.onTransportDeletion != nil { - h.onTransportDeletion(transport.SessionID) + if h.opts.OnSessionClose != nil { + h.opts.OnSessionClose(transport.SessionID) } }, } @@ -307,6 +309,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } else { // Otherwise, save the transport so that it can be reused h.mu.Lock() + transport.session = ss h.transports[transport.SessionID] = transport h.mu.Unlock() } @@ -369,6 +372,9 @@ type StreamableServerTransport struct { // connection is non-nil if and only if the transport has been connected. connection *streamableServerConn + + // the server session associated with this transport. + session *ServerSession } // Connect implements the [Transport] interface. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3b967f8f..cd3abfdf 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -103,7 +103,7 @@ func TestStreamableTransports(t *testing.T) { } handler.ServeHTTP(w, r) })) - defer httpServer.Close() + t.Cleanup(func() { httpServer.Close() }) // Create a client and connect it to the server using our StreamableClientTransport. // Check that all requests honor a custom client. @@ -130,7 +130,12 @@ func TestStreamableTransports(t *testing.T) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer session.Close() + t.Cleanup(func() { + err := session.Close() + if err != nil { + t.Errorf("session.Close() failed: %v", err) + } + }) sid := session.ID() if sid == "" { t.Fatalf("empty session ID") @@ -220,7 +225,7 @@ func TestStreamableServerShutdown(t *testing.T) { httpServer := httptest.NewUnstartedServer(handler) httpServer.Config.RegisterOnShutdown(func() { for session := range server.Sessions() { - session.Close() + _ = session.Close() } }) httpServer.Start() @@ -429,10 +434,13 @@ func TestServerTransportCleanup(t *testing.T) { }, }) - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) - handler.onTransportDeletion = func(sessionID string) { - chans[sessionID] <- struct{}{} - } + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{ + OnSessionClose: func(sessionID string) { + chans[sessionID] <- struct{}{} + }, + }, + ) httpServer := httptest.NewServer(handler) defer httpServer.Close() @@ -1430,3 +1438,41 @@ func TestStreamableGET(t *testing.T) { t.Errorf("GET with session ID: got status %d, want %d", got, want) } } + +// TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the +// OnSessionClose callback is called when the client closes the session. +func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) { + var closedSessions []string + + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + OnSessionClose: func(sessionID string) { + closedSessions = append(closedSessions, sessionID) + }, + }) + + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + + ctx := context.Background() + client := NewClient(testImpl, nil) + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + sessionID := session.ID() + t.Log("Closing client session") + err = session.Close() + if err != nil { + t.Fatalf("session.Close() failed: %v", err) + } + + if len(closedSessions) != 1 { + t.Fatalf("got %d closed sessions, want 1", len(closedSessions)) + } + if closedSessions[0] != sessionID { + t.Fatalf("got session ID %q, want %q", closedSessions[0], sessionID) + } +}