Skip to content

Commit 468b381

Browse files
authored
feat(server): add option to customize basePath (#45)
* feat: add option to customize basePath * tests
1 parent 490b11c commit 468b381

File tree

2 files changed

+96
-8
lines changed

2 files changed

+96
-8
lines changed

server/sse.go

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net/http"
88
"net/http/httptest"
9+
"strings"
910
"sync"
1011

1112
"github.com/google/uuid"
@@ -25,6 +26,7 @@ type sseSession struct {
2526
type SSEServer struct {
2627
server *MCPServer
2728
baseURL string
29+
basePath string
2830
messageEndpoint string
2931
sseEndpoint string
3032
sessions sync.Map
@@ -41,6 +43,18 @@ func WithBaseURL(baseURL string) Option {
4143
}
4244
}
4345

46+
// Add a new option for setting base path
47+
func WithBasePath(basePath string) Option {
48+
return func(s *SSEServer) {
49+
// Ensure the path starts with / and doesn't end with /
50+
if !strings.HasPrefix(basePath, "/") {
51+
basePath = "/" + basePath
52+
}
53+
s.basePath = strings.TrimSuffix(basePath, "/")
54+
s.baseURL = s.baseURL + s.basePath
55+
}
56+
}
57+
4458
// WithMessageEndpoint sets the message endpoint path
4559
func WithMessageEndpoint(endpoint string) Option {
4660
return func(s *SSEServer) {
@@ -68,6 +82,7 @@ func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer {
6882
server: server,
6983
sseEndpoint: "/sse",
7084
messageEndpoint: "/message",
85+
basePath: "",
7186
}
7287

7388
// Apply all options
@@ -299,12 +314,22 @@ func (s *SSEServer) SendEventToSession(
299314

300315
// ServeHTTP implements the http.Handler interface.
301316
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
302-
switch r.URL.Path {
303-
case s.sseEndpoint:
317+
path := r.URL.Path
318+
319+
// Construct the full SSE and message paths
320+
ssePath := s.basePath + s.sseEndpoint
321+
messagePath := s.basePath + s.messageEndpoint
322+
323+
// Use exact path matching rather than Contains
324+
if path == ssePath {
304325
s.handleSSE(w, r)
305-
case s.messageEndpoint:
326+
return
327+
}
328+
329+
if path == messagePath {
306330
s.handleMessage(w, r)
307-
default:
308-
http.NotFound(w, r)
331+
return
309332
}
333+
334+
http.NotFound(w, r)
310335
}

server/sse_test.go

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,29 @@ import (
1616
func TestSSEServer(t *testing.T) {
1717
t.Run("Can instantiate", func(t *testing.T) {
1818
mcpServer := NewMCPServer("test", "1.0.0")
19-
sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080"))
19+
sseServer := NewSSEServer(mcpServer,
20+
WithBaseURL("http://localhost:8080"),
21+
WithBasePath("/mcp"),
22+
)
2023

2124
if sseServer == nil {
2225
t.Error("SSEServer should not be nil")
2326
}
2427
if sseServer.server == nil {
2528
t.Error("MCPServer should not be nil")
2629
}
27-
if sseServer.baseURL != "http://localhost:8080" {
30+
if sseServer.baseURL != "http://localhost:8080/mcp" {
2831
t.Errorf(
29-
"Expected baseURL http://localhost:8080, got %s",
32+
"Expected baseURL http://localhost:8080/mcp, got %s",
3033
sseServer.baseURL,
3134
)
3235
}
36+
if sseServer.basePath != "/mcp" {
37+
t.Errorf(
38+
"Expected basePath /mcp, got %s",
39+
sseServer.basePath,
40+
)
41+
}
3342
})
3443

3544
t.Run("Can send and receive messages", func(t *testing.T) {
@@ -405,4 +414,58 @@ func TestSSEServer(t *testing.T) {
405414
// Clean up SSE connection
406415
cancel()
407416
})
417+
418+
t.Run("works as http.Handler with custom basePath", func(t *testing.T) {
419+
mcpServer := NewMCPServer("test", "1.0.0")
420+
sseServer := NewSSEServer(mcpServer, WithBasePath("/mcp"))
421+
422+
ts := httptest.NewServer(sseServer)
423+
defer ts.Close()
424+
425+
// Test 404 for unknown path first (simpler case)
426+
resp, err := http.Get(fmt.Sprintf("%s/sse", ts.URL))
427+
if err != nil {
428+
t.Fatalf("Failed to make request: %v", err)
429+
}
430+
defer resp.Body.Close()
431+
if resp.StatusCode != http.StatusNotFound {
432+
t.Errorf("Expected status 404, got %d", resp.StatusCode)
433+
}
434+
435+
// Test SSE endpoint with proper cleanup
436+
ctx, cancel := context.WithCancel(context.Background())
437+
defer cancel()
438+
439+
sseURL := fmt.Sprintf("%s/sse", ts.URL+sseServer.basePath)
440+
req, err := http.NewRequestWithContext(ctx, "GET", sseURL, nil)
441+
if err != nil {
442+
t.Fatalf("Failed to create request: %v", err)
443+
}
444+
445+
resp, err = http.DefaultClient.Do(req)
446+
if err != nil {
447+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
448+
}
449+
defer resp.Body.Close()
450+
451+
if resp.StatusCode != http.StatusOK {
452+
t.Errorf("Expected status 200, got %d", resp.StatusCode)
453+
}
454+
455+
// Read initial message in goroutine
456+
done := make(chan struct{})
457+
go func() {
458+
defer close(done)
459+
buf := make([]byte, 1024)
460+
_, err := resp.Body.Read(buf)
461+
if err != nil && err.Error() != "context canceled" {
462+
t.Errorf("Failed to read from SSE stream: %v", err)
463+
}
464+
}()
465+
466+
// Wait briefly for initial response then cancel
467+
time.Sleep(100 * time.Millisecond)
468+
cancel()
469+
<-done
470+
})
408471
}

0 commit comments

Comments
 (0)