Skip to content
30 changes: 25 additions & 5 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server
opts StreamableHTTPOptions

onTransportDeletion func(sessionID string) // for testing only

mu sync.Mutex
// TODO: we should store the ServerSession along with the transport, because
// we need to cancel keepalive requests when closing the transport.
Expand Down Expand Up @@ -77,6 +75,11 @@ type StreamableHTTPOptions struct {
//
// [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
JSONResponse bool

// OnConnectionClose is a callback function that is invoked when a [Connection]
// is closed. A connection is closed when the session is ended explicitly by
// the client or when it is interrupted due to a timeout or other errors.
OnConnectionClose func(sessionID string)
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand Down Expand Up @@ -166,7 +169,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
transport.connection.Close()
_ = transport.Close()
}
w.WriteHeader(http.StatusNoContent)
return
Expand Down Expand Up @@ -299,8 +302,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
if h.onTransportDeletion != nil {
h.onTransportDeletion(transport.SessionID)
if h.opts.OnConnectionClose != nil {
h.opts.OnConnectionClose(transport.SessionID)
}
},
}
Expand All @@ -320,6 +323,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
} else {
// Otherwise, save the transport so that it can be reused
h.mu.Lock()
transport.session = ss
h.transports[transport.SessionID] = transport
h.mu.Unlock()
}
Expand Down Expand Up @@ -382,6 +386,9 @@ type StreamableServerTransport struct {

// connection is non-nil if and only if the transport has been connected.
connection *streamableServerConn

// the server session associated with this transport
session *ServerSession
}

// Connect implements the [Transport] interface.
Expand Down Expand Up @@ -563,6 +570,19 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R
}
}

// Close releases resources related to this transport if it has already been connected.
func (t *StreamableServerTransport) Close() error {
var sessionErr, connErr error
if t.session != nil {
sessionErr = t.session.Close()
}
if t.connection != nil {
connErr = t.connection.Close()
}

return errors.Join(sessionErr, connErr)
}

// serveGET streams messages to a hanging http GET, with stream ID and last
// message parsed from the Last-Event-ID header.
//
Expand Down
54 changes: 48 additions & 6 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ func TestStreamableTransports(t *testing.T) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()
t.Cleanup(func() {
err := session.Close()
if err != nil {
t.Errorf("session.Close() failed: %v", err)
}
})
sid := session.ID()
if sid == "" {
t.Fatalf("empty session ID")
Expand Down Expand Up @@ -220,7 +225,7 @@ func TestStreamableServerShutdown(t *testing.T) {
httpServer := httptest.NewUnstartedServer(handler)
httpServer.Config.RegisterOnShutdown(func() {
for session := range server.Sessions() {
session.Close()
_ = session.Close()
}
})
httpServer.Start()
Expand Down Expand Up @@ -428,12 +433,11 @@ func TestServerTransportCleanup(t *testing.T) {
chans[fmt.Sprint(id)] = make(chan struct{}, 1)
return fmt.Sprint(id)
},
OnConnectionClose: func(sessionID string) {
chans[sessionID] <- struct{}{}
},
})

handler.onTransportDeletion = func(sessionID string) {
chans[sessionID] <- struct{}{}
}

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

Expand Down Expand Up @@ -1423,3 +1427,41 @@ func TestStreamableGET(t *testing.T) {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}
}

// TestStreamableHTTPHandler_OnConnectionClose_SessionDeletion tests that the
// OnConnectionClose callback is called when the client closes the session.
func TestStreamableHTTPHandler_OnConnectionClose_SessionDeletion(t *testing.T) {
var closedConnections []string

server := NewServer(testImpl, nil)
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
OnConnectionClose: func(sessionID string) {
closedConnections = append(closedConnections, sessionID)
},
})

httpServer := httptest.NewServer(handler)
t.Cleanup(httpServer.Close)

ctx := context.Background()
client := NewClient(testImpl, nil)
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

sessionID := session.ID()
t.Log("Closing client session")
err = session.Close()
if err != nil {
t.Fatalf("session.Close() failed: %v", err)
}

if len(closedConnections) != 1 {
t.Fatalf("got %d connections, want 1", len(closedConnections))
}
if closedConnections[0] != sessionID {
t.Fatalf("got session ID %q, want %q", closedConnections[0], sessionID)
}
}