Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,23 @@ type serverConnection interface {

// A StdioTransport is a [Transport] that communicates over stdin/stdout using
// newline-delimited JSON.
type StdioTransport struct{}
type StdioTransport struct {
In io.ReadCloser
Out io.WriteCloser
}

// Connect implements the [Transport] interface.
func (*StdioTransport) Connect(context.Context) (Connection, error) {
return newIOConn(rwc{os.Stdin, os.Stdout}), nil
func (t *StdioTransport) Connect(context.Context) (Connection, error) {
in := t.In
out := t.Out

if in == nil {
in = os.Stdin
}
if out == nil {
out = os.Stdout
}
return newIOConn(rwc{in, out}), nil
}

// An InMemoryTransport is a [Transport] that communicates over an in-memory
Expand Down
116 changes: 116 additions & 0 deletions mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,119 @@ func TestIOConnRead(t *testing.T) {
})
}
}

func TestStdioTransport(t *testing.T) {
tests := []struct {
name string
setupIn func() io.ReadCloser
setupOut func() io.WriteCloser
wantErr bool
}{
{
name: "defaults_use_stdin_stdout",
setupIn: func() io.ReadCloser { return nil },
setupOut: func() io.WriteCloser { return nil },
wantErr: false,
},
{
name: "custom_streams",
setupIn: func() io.ReadCloser { r, _ := io.Pipe(); return r },
setupOut: func() io.WriteCloser { _, w := io.Pipe(); return w },
wantErr: false,
},
{
name: "partial_custom_in_only",
setupIn: func() io.ReadCloser { return io.NopCloser(strings.NewReader("")) },
setupOut: func() io.WriteCloser { return nil },
wantErr: false,
},
{
name: "partial_custom_out_only",
setupIn: func() io.ReadCloser { return nil },
setupOut: func() io.WriteCloser { _, w := io.Pipe(); return w },
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := &StdioTransport{
In: tt.setupIn(),
Out: tt.setupOut(),
}

conn, err := transport.Connect(context.Background())
if (err != nil) != tt.wantErr {
t.Errorf("StdioTransport.Connect() error = %v, wantErr %v", err, tt.wantErr)
return
}

if conn == nil {
t.Error("StdioTransport.Connect() returned nil connection")
return
}

defer conn.Close()
})
}
}

func TestStdioTransportDefaults(t *testing.T) {
transport := &StdioTransport{}

if transport.In != nil {
t.Error("StdioTransport{}.In should be nil (uses default)")
}

if transport.Out != nil {
t.Error("StdioTransport{}.Out should be nil (uses default)")
}

conn, err := transport.Connect(context.Background())
if err != nil {
t.Fatalf("StdioTransport{}.Connect() failed: %v", err)
}
defer conn.Close()
}

func TestStdioTransportReadWrite(t *testing.T) {
ctx := context.Background()
r, w := io.Pipe()
defer r.Close()
defer w.Close()

transport := &StdioTransport{
In: r,
Out: w,
}

conn, err := transport.Connect(ctx)
if err != nil {
t.Fatalf("StdioTransport.Connect() failed: %v", err)
}
defer conn.Close()

// Test that we can write a message and it gets transmitted
testMsg := &jsonrpc.Request{
ID: jsonrpc2.Int64ID(1),
Method: "test",
Params: nil,
}

// Write message in a goroutine since pipe may block
go func() {
if err := conn.Write(ctx, testMsg); err != nil {
t.Errorf("conn.Write() failed: %v", err)
}
}()

// Read the message back
receivedMsg, err := conn.Read(ctx)
if err != nil {
t.Fatalf("conn.Read() failed: %v", err)
}

if req, ok := receivedMsg.(*jsonrpc.Request); !ok || req.Method != "test" {
t.Errorf("Expected request with method 'test', got %v", receivedMsg)
}
}