Skip to content
22 changes: 14 additions & 8 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@ type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server
opts StreamableHTTPOptions

onTransportDeletion func(sessionID string) // for testing only

mu sync.Mutex
// TODO: we should store the ServerSession along with the transport, because
// we need to cancel keepalive requests when closing the transport.
mu sync.Mutex
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
}

Expand All @@ -67,6 +63,11 @@ type StreamableHTTPOptions struct {
//
// [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
JSONResponse bool

// OnSessionClose is a callback function that is invoked when a [ServerSession]
// is closed. This happens when a session is ended explicitly by the MCP client
// or when it is interrupted due to a timeout or other errors.
OnSessionClose func(sessionID string)
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand Down Expand Up @@ -153,7 +154,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
transport.connection.Close()
// TODO: consider logging this error
_ = transport.session.Close()
}
w.WriteHeader(http.StatusNoContent)
return
Expand Down Expand Up @@ -286,8 +288,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
if h.onTransportDeletion != nil {
h.onTransportDeletion(transport.SessionID)
if h.opts.OnSessionClose != nil {
h.opts.OnSessionClose(transport.SessionID)
}
},
}
Expand All @@ -307,6 +309,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
} else {
// Otherwise, save the transport so that it can be reused
h.mu.Lock()
transport.session = ss
h.transports[transport.SessionID] = transport
h.mu.Unlock()
}
Expand Down Expand Up @@ -369,6 +372,9 @@ type StreamableServerTransport struct {

// connection is non-nil if and only if the transport has been connected.
connection *streamableServerConn

// the server session associated with this transport.
session *ServerSession
}

// Connect implements the [Transport] interface.
Expand Down
60 changes: 53 additions & 7 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestStreamableTransports(t *testing.T) {
}
handler.ServeHTTP(w, r)
}))
defer httpServer.Close()
t.Cleanup(func() { httpServer.Close() })

// Create a client and connect it to the server using our StreamableClientTransport.
// Check that all requests honor a custom client.
Expand All @@ -130,7 +130,12 @@ func TestStreamableTransports(t *testing.T) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()
t.Cleanup(func() {
err := session.Close()
if err != nil {
t.Errorf("session.Close() failed: %v", err)
}
})
sid := session.ID()
if sid == "" {
t.Fatalf("empty session ID")
Expand Down Expand Up @@ -220,7 +225,7 @@ func TestStreamableServerShutdown(t *testing.T) {
httpServer := httptest.NewUnstartedServer(handler)
httpServer.Config.RegisterOnShutdown(func() {
for session := range server.Sessions() {
session.Close()
_ = session.Close()
}
})
httpServer.Start()
Expand Down Expand Up @@ -429,10 +434,13 @@ func TestServerTransportCleanup(t *testing.T) {
},
})

handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
handler.onTransportDeletion = func(sessionID string) {
chans[sessionID] <- struct{}{}
}
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server },
&StreamableHTTPOptions{
OnSessionClose: func(sessionID string) {
chans[sessionID] <- struct{}{}
},
},
)

httpServer := httptest.NewServer(handler)
defer httpServer.Close()
Expand Down Expand Up @@ -1430,3 +1438,41 @@ func TestStreamableGET(t *testing.T) {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}
}

// TestStreamableHTTPHandler_OnSessionClose_SessionDeletion tests that the
// OnSessionClose callback is called when the client closes the session.
func TestStreamableHTTPHandler_OnSessionClose_SessionDeletion(t *testing.T) {
var closedSessions []string

server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
OnSessionClose: func(sessionID string) {
closedSessions = append(closedSessions, sessionID)
},
})

httpServer := httptest.NewServer(handler)
t.Cleanup(httpServer.Close)

ctx := context.Background()
client := NewClient(testImpl, nil)
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

sessionID := session.ID()
t.Log("Closing client session")
err = session.Close()
if err != nil {
t.Fatalf("session.Close() failed: %v", err)
}

if len(closedSessions) != 1 {
t.Fatalf("got %d closed sessions, want 1", len(closedSessions))
}
if closedSessions[0] != sessionID {
t.Fatalf("got session ID %q, want %q", closedSessions[0], sessionID)
}
}