diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index d5588fdf..a70e3e31 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -56,6 +56,7 @@ + diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index c026acb9..3ec0d288 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -90,13 +90,13 @@ public async Task ConnectAsync(CancellationToken cancellationToken = #if NET foreach (string arg in arguments) { - startInfo.ArgumentList.Add(arg); + startInfo.ArgumentList.Add(EscapeArgumentString(arg)); } #else StringBuilder argsBuilder = new(); foreach (string arg in arguments) { - PasteArguments.AppendArgument(argsBuilder, arg); + PasteArguments.AppendArgument(argsBuilder, EscapeArgumentString(arg)); } startInfo.Arguments = argsBuilder.ToString(); @@ -236,6 +236,26 @@ internal static bool HasExited(Process process) } } + private static string EscapeArgumentString(string argument) => + RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && !ContainsWhitespaceRegex.IsMatch(argument) ? + WindowsCliSpecialArgumentsRegex.Replace(argument, static match => "^" + match.Value) : + argument; + + private const string WindowsCliSpecialArgumentsRegexString = "[&^><|]"; + +#if NET + private static Regex WindowsCliSpecialArgumentsRegex => GetWindowsCliSpecialArgumentsRegex(); + private static Regex ContainsWhitespaceRegex => GetContainsWhitespaceRegex(); + + [GeneratedRegex(WindowsCliSpecialArgumentsRegexString, RegexOptions.CultureInvariant)] + private static partial Regex GetWindowsCliSpecialArgumentsRegex(); + [GeneratedRegex(@"\s", RegexOptions.CultureInvariant)] + private static partial Regex GetContainsWhitespaceRegex(); +#else + private static Regex WindowsCliSpecialArgumentsRegex { get; } = new(WindowsCliSpecialArgumentsRegexString, RegexOptions.Compiled | RegexOptions.CultureInvariant); + private static Regex ContainsWhitespaceRegex { get; } = new(@"\s", RegexOptions.Compiled | RegexOptions.CultureInvariant); +#endif + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} connecting.")] private static partial void LogTransportConnecting(ILogger logger, string endpointName); diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 75271e34..9765ed92 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -36,13 +36,14 @@ private static async Task Main(string[] args) { Log.Logger.Information("Starting server..."); + string? cliArg = ParseCliArgument(args); McpServerOptions options = new() { Capabilities = new ServerCapabilities(), ServerInstructions = "This is a test server with only stub functionality", }; - ConfigureTools(options); + ConfigureTools(options, cliArg); ConfigureResources(options); ConfigurePrompts(options); ConfigureLogging(options); @@ -104,7 +105,7 @@ await server.SendMessageAsync(new JsonRpcNotification } } - private static void ConfigureTools(McpServerOptions options) + private static void ConfigureTools(McpServerOptions options, string? cliArg) { options.Handlers.ListToolsHandler = async (request, cancellationToken) => { @@ -199,6 +200,13 @@ private static void ConfigureTools(McpServerOptions options) Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] }; } + else if (request.Params?.Name == "echoCliArg") + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = cliArg ?? "null" }] + }; + } else { throw new McpException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); @@ -522,6 +530,19 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; } + private static string? ParseCliArgument(string[] args) + { + foreach (var arg in args) + { + if (arg.StartsWith("--cli-arg=")) + { + return arg["--cli-arg=".Length..]; + } + } + + return null; + } + const string MCP_TINY_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAAKsGlDQ1BJQ0MgUHJvZmlsZQAASImVlwdUU+kSgOfe9JDQEiIgJfQmSCeAlBBaAAXpYCMkAUKJMRBU7MriClZURLCs6KqIgo0idizYFsWC3QVZBNR1sWDDlXeBQ9jdd9575805c+a7c+efmf+e/z9nLgCdKZDJMlF1gCxpjjwyyI8dn5DIJvUABRiY0kBdIMyWcSMiwgCTUft3+dgGyJC9YzuU69/f/1fREImzhQBIBMbJomxhFsbHMe0TyuQ5ALg9mN9kbo5siK9gzJRjDWL8ZIhTR7hviJOHGY8fjomO5GGsDUCmCQTyVACaKeZn5wpTsTw0f4ztpSKJFGPsGbyzsmaLMMbqgiUWI8N4KD8n+S95Uv+WM1mZUyBIVfLIXoaF7C/JlmUK5v+fn+N/S1amYrSGOaa0NHlwJGaxvpAHGbNDlSxNnhI+yhLRcPwwpymCY0ZZmM1LHGWRwD9UuTZzStgop0gC+co8OfzoURZnB0SNsnx2pLJWipzHHWWBfKyuIiNG6U8T85X589Ki40Y5VxI7ZZSzM6JCx2J4Sr9cEansXywN8hurG6jce1b2X/Yr4SvX5qRFByv3LhjrXyzljuXMjlf2JhL7B4zFxCjjZTl+ylqyzAhlvDgzSOnPzo1Srs3BDuTY2gjlN0wXhESMMoRBELAhBjIhB+QggECQgBTEOeJ5Q2cUeLNl8+WS1LQcNhe7ZWI2Xyq0m8B2tHd0Bhi6syNH4j1r+C4irGtjvhWVAF4nBgcHT475Qm4BHEkCoNaO+SxnAKh3A1w5JVTIc0d8Q9cJCEAFNWCCDhiACViCLTiCK3iCLwRACIRDNCTATBBCGmRhnc+FhbAMCqAI1sNmKIOdsBv2wyE4CvVwCs7DZbgOt+AePIZ26IJX0AcfYQBBEBJCRxiIDmKImCE2iCPCQbyRACQMiUQSkCQkFZEiCmQhsgIpQoqRMmQXUokcQU4g55GrSCvyEOlAepF3yFcUh9JQJqqPmqMTUQ7KRUPRaHQGmorOQfPQfHQtWopWoAfROvQ8eh29h7ajr9B+HOBUcCycEc4Wx8HxcOG4RFwKTo5bjCvEleAqcNW4Rlwz7g6uHfca9wVPxDPwbLwt3hMfjI/BC/Fz8Ivxq/Fl+P34OvxF/B18B74P/51AJ+gRbAgeBD4hnpBKmEsoIJQQ9hJqCZcI9whdhI9EIpFFtCC6EYOJCcR04gLiauJ2Yg3xHLGV2EnsJ5FIOiQbkhcpnCQg5ZAKSFtJB0lnSbdJXaTPZBWyIdmRHEhOJEvJy8kl5APkM+Tb5G7yAEWdYkbxoIRTRJT5lHWUPZRGyk1KF2WAqkG1oHpRo6np1GXUUmo19RL1CfW9ioqKsYq7ylQVicpSlVKVwypXVDpUvtA0adY0Hm06TUFbS9tHO0d7SHtPp9PN6b70RHoOfS29kn6B/oz+WZWhaqfKVxWpLlEtV61Tva36Ro2iZqbGVZuplqdWonZM7abaa3WKurk6T12gvli9XP2E+n31fg2GhoNGuEaWxmqNAxpXNXo0SZrmmgGaIs18zd2aFzQ7GTiGCYPHEDJWMPYwLjG6mESmBZPPTGcWMQ8xW5h9WppazlqxWvO0yrVOa7WzcCxzFp+VyVrHOspqY30dpz+OO048btW46nG3x33SHq/tqy3WLtSu0b6n/VWHrROgk6GzQade56kuXtdad6ruXN0dupd0X49njvccLxxfOP7o+Ed6qJ61XqTeAr3dejf0+vUN9IP0Zfpb9S/ovzZgGfgapBtsMjhj0GvIMPQ2lBhuMjxr+JKtxeayM9ml7IvsPiM9o2AjhdEuoxajAWML4xjj5cY1xk9NqCYckxSTTSZNJn2mhqaTTReaVpk+MqOYcczSzLaYNZt9MrcwjzNfaV5v3mOhbcG3yLOosnhiSbf0sZxjWWF514poxbHKsNpudcsatXaxTrMut75pg9q42khsttu0TiBMcJ8gnVAx4b4tzZZrm2tbZdthx7ILs1tuV2/3ZqLpxMSJGyY2T/xu72Kfab/H/rGDpkOIw3KHRod3jtaOQsdyx7tOdKdApyVODU5vnW2cxc47nB+4MFwmu6x0aXL509XNVe5a7drrZuqW5LbN7T6HyYngrOZccSe4+7kvcT/l/sXD1SPH46jHH562nhmeBzx7JllMEk/aM6nTy9hL4LXLq92b7Z3k/ZN3u4+Rj8Cnwue5r4mvyHevbzfXipvOPch942fvJ/er9fvE8+At4p3zx/kH+Rf6twRoBsQElAU8CzQOTA2sCuwLcglaEHQumBAcGrwh+D5fny/kV/L7QtxCFoVcDKWFRoWWhT4Psw6ThzVORieHTN44+ckUsynSKfXhEM4P3xj+NMIiYk7EyanEqRFTy6e+iHSIXBjZHMWImhV1IOpjtF/0uujHMZYxipimWLXY6bGVsZ/i/OOK49rjJ8Yvir+eoJsgSWhIJCXGJu5N7J8WMG3ztK7pLtMLprfNsJgxb8bVmbozM2eenqU2SzDrWBIhKS7pQNI3QbigQtCfzE/eltwn5Am3CF+JfEWbRL1iL3GxuDvFK6U4pSfVK3Vjam+aT1pJ2msJT1ImeZsenL4z/VNGeMa+jMHMuMyaLHJWUtYJqaY0Q3pxtsHsebNbZTayAln7HI85m+f0yUPle7OR7BnZDTlMbDi6obBU/KDoyPXOLc/9PDd27rF5GvOk827Mt56/an53XmDezwvwC4QLmhYaLVy2sGMRd9Guxcji5MVNS0yW5C/pWhq0dP8y6rKMZb8st19evPzDirgVjfn6+UvzO38I+qGqQLVAXnB/pefKnT/if5T82LLKadXWVd8LRYXXiuyLSoq+rRauvrbGYU3pmsG1KWtb1rmu27GeuF66vm2Dz4b9xRrFecWdGydvrNvE3lS46cPmWZuvljiX7NxC3aLY0l4aVtqw1XTr+q3fytLK7pX7ldds09u2atun7aLtt3f47qjeqb+zaOfXnyQ/PdgVtKuuwryiZDdxd+7uF3ti9zT/zPm5cq/u3qK9f+6T7mvfH7n/YqVbZeUBvQPrqtAqRVXvwekHbx3yP9RQbVu9q4ZVU3QYDisOvzySdKTtaOjRpmOcY9XHzY5vq2XUFtYhdfPr+urT6tsbEhpaT4ScaGr0bKw9aXdy3ymjU+WntU6vO0M9k39m8Gze2f5zsnOvz6ee72ya1fT4QvyFuxenXmy5FHrpyuXAyxeauc1nr3hdOXXV4+qJa5xr9dddr9fdcLlR+4vLL7Utri11N91uNtzyv9XYOqn1zG2f2+fv+N+5fJd/9/q9Kfda22LaHtyffr/9gehBz8PMh28f5T4aeLz0CeFJ4VP1pyXP9J5V/Gr1a027a/vpDv+OG8+jnj/uFHa++i37t29d+S/oL0q6Dbsrexx7TvUG9t56Oe1l1yvZq4HXBb9r/L7tjeWb43/4/nGjL76v66387eC71e913u/74PyhqT+i/9nHrI8Dnwo/63ze/4Xzpflr3NfugbnfSN9K/7T6s/F76Pcng1mDgzKBXDA8CuAwRVNSAN7tA6AnADCwGYI6bWSmHhZk5D9gmOA/8cjcPSyuANWYGRqNeOcADmNqvhRAzRdgaCyK9gXUyUmpo/Pv8Kw+JAbYv8K0HECi2x6tebQU/iEjc/xf+v6nBWXWv9l/AV0EC6JTIblRAAAAeGVYSWZNTQAqAAAACAAFARIAAwAAAAEAAQAAARoABQAAAAEAAABKARsABQAAAAEAAABSASgAAwAAAAEAAgAAh2kABAAAAAEAAABaAAAAAAAAAJAAAAABAAAAkAAAAAEAAqACAAQAAAABAAAAFKADAAQAAAABAAAAFAAAAAAXNii1AAAACXBIWXMAABYlAAAWJQFJUiTwAAAB82lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNi4wLjAiPgogICA8cmRmOlJERiB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiPgogICAgICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgICAgICAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyI+CiAgICAgICAgIDx0aWZmOllSZXNvbHV0aW9uPjE0NDwvdGlmZjpZUmVzb2x1dGlvbj4KICAgICAgICAgPHRpZmY6T3JpZW50YXRpb24+MTwvdGlmZjpPcmllbnRhdGlvbj4KICAgICAgICAgPHRpZmY6WFJlc29sdXRpb24+MTQ0PC90aWZmOlhSZXNvbHV0aW9uPgogICAgICAgICA8dGlmZjpSZXNvbHV0aW9uVW5pdD4yPC90aWZmOlJlc29sdXRpb25Vbml0PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KReh49gAAAjRJREFUOBGFlD2vMUEUx2clvoNCcW8hCqFAo1dKhEQpvsF9KrWEBh/ALbQ0KkInBI3SWyGPCCJEQliXgsTLefaca/bBWjvJzs6cOf/fnDkzOQJIjWm06/XKBEGgD8c6nU5VIWgBtQDPZPWtJE8O63a7LBgMMo/Hw0ql0jPjcY4RvmqXy4XMjUYDUwLtdhtmsxnYbDbI5/O0djqdFFKmsEiGZ9jP9gem0yn0ej2Yz+fg9XpfycimAD7DttstQTDKfr8Po9GIIg6Hw1Cr1RTgB+A72GAwgMPhQLBMJgNSXsFqtUI2myUo18pA6QJogefsPrLBX4QdCVatViklw+EQRFGEj88P2O12pEUGATmsXq+TaLPZ0AXgMRF2vMEqlQoJTSYTpNNpApvNZliv1/+BHDaZTAi2Wq1A3Ig0xmMej7+RcZjdbodUKkWAaDQK+GHjHPnImB88JrZIJAKFQgH2+z2BOczhcMiwRCIBgUAA+NN5BP6mj2DYff35gk6nA61WCzBn2JxO5wPM7/fLz4vD0E+OECfn8xl/0Gw2KbLxeAyLxQIsFgt8p75pDSO7h/HbpUWpewCike9WLpfB7XaDy+WCYrFI/slk8i0MnRRAUt46hPMI4vE4+Hw+ec7t9/44VgWigEeby+UgFArJWjUYOqhWG6x50rpcSfR6PVUfNOgEVRlTX0HhrZBKz4MZjUYWi8VoA+lc9H/VaRZYjBKrtXR8tlwumcFgeMWRbZpA9ORQWfVm8A/FsrLaxebd5wAAAABJRU5ErkJggg=="; } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/PlatformDetection.cs b/tests/ModelContextProtocol.Tests/PlatformDetection.cs index 1eef9942..f439147f 100644 --- a/tests/ModelContextProtocol.Tests/PlatformDetection.cs +++ b/tests/ModelContextProtocol.Tests/PlatformDetection.cs @@ -1,6 +1,9 @@ +using System.Runtime.InteropServices; + namespace ModelContextProtocol.Tests; internal static class PlatformDetection { public static bool IsMonoRuntime { get; } = Type.GetType("Mono.Runtime") is not null; + public static bool IsWindows { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 48c2b953..5394ba30 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; using System.Runtime.InteropServices; using System.Text; @@ -15,17 +16,14 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() string id = Guid.NewGuid().ToString("N"); StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }, LoggerFactory) : - new(new() { Command = "ls", Arguments = [id] }, LoggerFactory); + new(new() { Command = "cmd", Arguments = ["/c", $"echo {id} >&2 & exit /b 1"] }, LoggerFactory) : + new(new() { Command = "sh", Arguments = ["-c", $"echo {id} >&2; exit 1"] }, LoggerFactory); - IOException e = await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); - if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - Assert.Contains(id, e.ToString()); - } + await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } - [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] + // [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { string id = Guid.NewGuid().ToString("N"); @@ -43,12 +41,92 @@ public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() }; StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"], StandardErrorLines = stdErrCallback }, LoggerFactory) : - new(new() { Command = "ls", Arguments = [id], StandardErrorLines = stdErrCallback }, LoggerFactory); + new(new() { Command = "cmd", Arguments = ["/c", $"echo {id} >&2 & exit /b 1"], StandardErrorLines = stdErrCallback }, LoggerFactory) : + new(new() { Command = "sh", Arguments = ["-c", $"echo {id} >&2; exit 1"], StandardErrorLines = stdErrCallback }, LoggerFactory); await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.InRange(count, 1, int.MaxValue); Assert.Contains(id, sb.ToString()); } + + [Theory] + [InlineData(null)] + [InlineData("argument with spaces")] + [InlineData("&")] + [InlineData("|")] + [InlineData(">")] + [InlineData("<")] + [InlineData("^")] + [InlineData(" & ")] + [InlineData(" | ")] + [InlineData(" > ")] + [InlineData(" < ")] + [InlineData(" ^ ")] + [InlineData("& ")] + [InlineData("| ")] + [InlineData("> ")] + [InlineData("< ")] + [InlineData("^ ")] + [InlineData(" &")] + [InlineData(" |")] + [InlineData(" >")] + [InlineData(" <")] + [InlineData(" ^")] + [InlineData("^&<>|")] + [InlineData("^&<>| ")] + [InlineData(" ^&<>|")] + [InlineData("\t^&<>")] + [InlineData("^&\t<>")] + [InlineData("ls /tmp | grep foo.txt > /dev/null")] + [InlineData("let rec Y f x = f (Y f) x")] + [InlineData("value with \"quotes\" and spaces")] + [InlineData("C:\\Program Files\\Test App\\app.dll")] + [InlineData("C:\\EndsWithBackslash\\")] + [InlineData("--already-looks-like-flag")] + [InlineData("-starts-with-dash")] + [InlineData("name=value=another")] + [InlineData("$(echo injected)")] + [InlineData("value-with-\"quotes\"-and-\\backslashes\\")] + [InlineData("http://localhost:1234/callback?foo=1&bar=2")] + public async Task EscapesCliArgumentsCorrectly(string? cliArgumentValue) + { + if (PlatformDetection.IsMonoRuntime && cliArgumentValue?.EndsWith("\\") is true) + { + Assert.Skip("mono runtime does not handle arguments ending with backslash correctly."); + } + + string cliArgument = $"--cli-arg={cliArgumentValue}"; + + StdioClientTransportOptions options = new() + { + Name = "TestServer", + Command = (PlatformDetection.IsMonoRuntime, PlatformDetection.IsWindows) switch + { + (true, _) => "mono", + (_, true) => "TestServer.exe", + _ => "dotnet", + }, + Arguments = (PlatformDetection.IsMonoRuntime, PlatformDetection.IsWindows) switch + { + (true, _) => ["TestServer.exe", cliArgument], + (_, true) => [cliArgument], + _ => ["TestServer.dll", cliArgument], + }, + }; + + var transport = new StdioClientTransport(options, LoggerFactory); + + // Act: Create client (handshake) and list tools to ensure full round trip works with the argument present. + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(tools); + Assert.NotEmpty(tools); + + var result = await client.CallToolAsync("echoCliArg", cancellationToken: TestContext.Current.CancellationToken); + var content = Assert.IsType(Assert.Single(result.Content)); + Assert.Equal(cliArgumentValue ?? "", content.Text); + } }