Skip to content

Commit 5b44363

Browse files
committed
fix race (#232)
1 parent d9bf37e commit 5b44363

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

mcp/conformance_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
135135
// Connect the server, and connect the client stream,
136136
// but don't connect an actual client.
137137
cTransport, sTransport := NewInMemoryTransports()
138-
ss, err := s.Connect(ctx, sTransport)
138+
ss, err := s.Connect(ctx, sTransport, nil)
139139
if err != nil {
140140
t.Fatal(err)
141141
}

mcp/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ func (s *Server) disconnect(cc *ServerSession) {
559559
}
560560
}
561561

562-
type ServerSessionOptions struct {
562+
type SessionOptions struct {
563563
SessionID string
564564
SessionState *SessionState
565565
SessionStore SessionStore
@@ -571,7 +571,7 @@ type ServerSessionOptions struct {
571571
// It returns a connection object that may be used to terminate the connection
572572
// (with [Connection.Close]), or await client termination (with
573573
// [Connection.Wait]).
574-
func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) {
574+
func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) (*ServerSession, error) {
575575
if opts != nil && opts.SessionState == nil && opts.SessionStore != nil {
576576
return nil, errors.New("ServerSessionOptions has store but no state")
577577
}
@@ -634,7 +634,7 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
634634
type ServerSession struct {
635635
server *Server
636636
conn *jsonrpc2.Connection
637-
opts ServerSessionOptions
637+
opts SessionOptions
638638
mu sync.Mutex
639639
logLevel LoggingLevel
640640
_initialized bool

mcp/streamable.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,26 +117,26 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
117117
return
118118
}
119119

120-
var session *StreamableServerTransport
120+
var transport *StreamableServerTransport
121121
sessionID := req.Header.Get(sessionIDHeader)
122122
if sessionID != "" {
123123
h.transportMu.Lock()
124-
session, _ = h.transports[sessionID]
124+
transport, _ = h.transports[sessionID]
125125
h.transportMu.Unlock()
126126
}
127127

128128
// TODO(rfindley): simplify the locking so that each request has only one
129129
// critical section.
130130
if req.Method == http.MethodDelete {
131-
if session == nil {
131+
if transport == nil {
132132
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
133133
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
134134
return
135135
}
136136
h.transportMu.Lock()
137-
delete(h.transports, session.sessionID)
137+
delete(h.transports, transport.sessionID)
138138
h.transportMu.Unlock()
139-
session.Close()
139+
transport.Close()
140140
w.WriteHeader(http.StatusNoContent)
141141
return
142142
}
@@ -149,7 +149,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
149149
return
150150
}
151151

152-
if session == nil {
152+
if transport == nil {
153153
var state *SessionState
154154
var err error
155155
if sessionID != "" {
@@ -162,16 +162,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
162162
http.Error(w, fmt.Sprintf("SessionStore.Load(%q): %v", sessionID, err), http.StatusInternalServerError)
163163
return
164164
}
165-
session = NewStreamableServerTransport(sessionID, nil)
166165
} else {
167166
state = &SessionState{}
168167
sessionID = randText()
169168
if err := h.opts.SessionStore.Store(req.Context(), sessionID, state); err != nil {
170169
http.Error(w, fmt.Sprintf("SessionStore.Store, new session: %v", err), http.StatusInternalServerError)
171170
return
172171
}
173-
session = NewStreamableServerTransport(sessionID, nil)
174172
}
173+
transport = NewStreamableServerTransport(sessionID, nil)
175174
server := h.getServer(req)
176175
if server == nil {
177176
http.Error(w, "no server available", http.StatusBadRequest)
@@ -180,7 +179,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
180179
// Pass req.Context() here, to allow middleware to add context values.
181180
// The context is detached in the jsonrpc2 library when handling the
182181
// long-running stream.
183-
_, err = server.Connect(req.Context(), session, &ServerSessionOptions{
182+
// TODO: rename SessionOptions to ConnectOptions?
183+
_, err = server.Connect(req.Context(), transport, &SessionOptions{
184184
SessionID: sessionID,
185185
SessionState: state,
186186
SessionStore: h.opts.SessionStore,
@@ -190,11 +190,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
190190
return
191191
}
192192
h.transportMu.Lock()
193-
h.transports[session.sessionID] = session
193+
// Check in case another request with the same stored session ID got here first.
194+
if _, ok := h.transports[transport.sessionID]; !ok {
195+
h.transports[transport.sessionID] = transport
196+
}
194197
h.transportMu.Unlock()
195198
}
196199

197-
session.ServeHTTP(w, req)
200+
transport.ServeHTTP(w, req)
198201
}
199202

200203
type StreamableServerTransportOptions struct {

0 commit comments

Comments
 (0)