diff --git a/samples/anthropic/tools/ToolsConsole/McpToolExtensions.cs b/samples/anthropic/tools/ToolsConsole/McpToolExtensions.cs index 486bb62f..6b534f9b 100644 --- a/samples/anthropic/tools/ToolsConsole/McpToolExtensions.cs +++ b/samples/anthropic/tools/ToolsConsole/McpToolExtensions.cs @@ -16,9 +16,7 @@ public static class McpToolExtensions List result = []; foreach (var tool in tools) { - var function = tool.InputSchema == null - ? new Function(tool.Name, tool.Description) - : new Function(tool.Name, tool.Description, JsonSerializer.Serialize(tool.InputSchema)); + var function = new Function(tool.Name, tool.Description, JsonSerializer.SerializeToNode(tool.InputSchema)); result.Add(function); } return result; diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 5251f7b5..5d18d0b6 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -466,8 +466,6 @@ private static JsonRpcRequest CreateRequest(string method, DictionaryProvides an AI function that calls a tool through . private sealed class McpAIFunction(IMcpClient client, Tool tool) : AIFunction { - private JsonElement? _jsonSchema; - /// public override string Name => tool.Name; @@ -475,15 +473,7 @@ private sealed class McpAIFunction(IMcpClient client, Tool tool) : AIFunction public override string Description => tool.Description ?? string.Empty; /// - public override JsonElement JsonSchema => _jsonSchema ??= - JsonSerializer.SerializeToElement(new Dictionary - { - ["type"] = "object", - ["title"] = tool.Name, - ["description"] = tool.Description ?? string.Empty, - ["properties"] = tool.InputSchema?.Properties ?? [], - ["required"] = tool.InputSchema?.Required ?? [] - }, McpJsonUtilities.JsonContext.Default.DictionaryStringObject); + public override JsonElement JsonSchema => tool.InputSchema; /// protected async override Task InvokeCoreAsync( diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs index af1de386..544c8df5 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs @@ -96,7 +96,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params { Name = function.Name, Description = function.Description, - InputSchema = JsonSerializer.Deserialize(function.JsonSchema, McpJsonUtilities.JsonContext.Default.JsonSchema), + InputSchema = function.JsonSchema, }); callbacks.Add(function.Name, async (request, cancellationToken) => diff --git a/src/ModelContextProtocol/Protocol/Types/JsonSchema.cs b/src/ModelContextProtocol/Protocol/Types/JsonSchema.cs deleted file mode 100644 index 384dc977..00000000 --- a/src/ModelContextProtocol/Protocol/Types/JsonSchema.cs +++ /dev/null @@ -1,26 +0,0 @@ -namespace ModelContextProtocol.Protocol.Types; - -/// -/// Represents a JSON schema for a tool's input arguments. -/// See the schema for details -/// -public class JsonSchema -{ - /// - /// The type of the schema, should be "object". - /// - [System.Text.Json.Serialization.JsonPropertyName("type")] - public string Type { get; set; } = "object"; - - /// - /// Map of property names to property definitions. - /// - [System.Text.Json.Serialization.JsonPropertyName("properties")] - public Dictionary? Properties { get; set; } - - /// - /// List of required property names. - /// - [System.Text.Json.Serialization.JsonPropertyName("required")] - public List? Required { get; set; } -} diff --git a/src/ModelContextProtocol/Protocol/Types/JsonSchemaProperty.cs b/src/ModelContextProtocol/Protocol/Types/JsonSchemaProperty.cs deleted file mode 100644 index 8fdd8f30..00000000 --- a/src/ModelContextProtocol/Protocol/Types/JsonSchemaProperty.cs +++ /dev/null @@ -1,20 +0,0 @@ -namespace ModelContextProtocol.Protocol.Types; - -/// -/// Represents a property in a JSON schema. -/// See the schema for details -/// -public class JsonSchemaProperty -{ - /// - /// The type of the property. Should be a JSON Schema type and is required. - /// - [System.Text.Json.Serialization.JsonPropertyName("type")] - public string Type { get; set; } = string.Empty; - - /// - /// A human-readable description of the property. - /// - [System.Text.Json.Serialization.JsonPropertyName("description")] - public string? Description { get; set; } = string.Empty; -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/Tool.cs b/src/ModelContextProtocol/Protocol/Types/Tool.cs index 87c0f0a0..a36cef2e 100644 --- a/src/ModelContextProtocol/Protocol/Types/Tool.cs +++ b/src/ModelContextProtocol/Protocol/Types/Tool.cs @@ -1,4 +1,6 @@ -using System.Text.Json.Serialization; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -23,6 +25,23 @@ public class Tool /// /// A JSON Schema object defining the expected parameters for the tool. /// + /// + /// Needs to a valid JSON schema object that additionally is of type object. + /// [JsonPropertyName("inputSchema")] - public JsonSchema? InputSchema { get; set; } + public JsonElement InputSchema + { + get => _inputSchema; + set + { + if (!McpJsonUtilities.IsValidMcpToolSchema(value)) + { + throw new ArgumentException("The specified document is not a valid MPC tool JSON schema.", nameof(InputSchema)); + } + + _inputSchema = value; + } + } + + private JsonElement _inputSchema = McpJsonUtilities.DefaultMcpToolSchema; } diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 3c4837ee..7a868a6d 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -77,6 +77,31 @@ private static JsonSerializerOptions CreateDefaultOptions() internal static JsonTypeInfo GetTypeInfo(this JsonSerializerOptions options) => (JsonTypeInfo)options.GetTypeInfo(typeof(T)); + internal static JsonElement DefaultMcpToolSchema = ParseJsonElement("{\"type\":\"object\"}"u8); + internal static bool IsValidMcpToolSchema(JsonElement element) + { + if (element.ValueKind is not JsonValueKind.Object) + { + return false; + } + + foreach (JsonProperty property in element.EnumerateObject()) + { + if (property.NameEquals("type")) + { + if (property.Value.ValueKind is not JsonValueKind.String || + !property.Value.ValueEquals("object")) + { + return false; + } + + return true; // No need to check other properties + } + } + + return false; // No type keyword found. + } + // Keep in sync with CreateDefaultOptions above. [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, UseStringEnumConverter = true, @@ -96,7 +121,13 @@ internal static JsonTypeInfo GetTypeInfo(this JsonSerializerOptions option [JsonSerializable(typeof(CreateMessageResult))] [JsonSerializable(typeof(ListRootsResult))] [JsonSerializable(typeof(InitializeResult))] - [JsonSerializable(typeof(JsonSchema))] [JsonSerializable(typeof(CallToolResponse))] internal sealed partial class JsonContext : JsonSerializerContext; + + private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) + { + Utf8JsonReader reader = new(utf8Json); + return JsonElement.ParseValue(ref reader); + } + } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index ab28ef3a..5cf30781 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -1,10 +1,11 @@ -using System.Text; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using Microsoft.Extensions.Logging; using Serilog; +using System.Text; +using System.Text.Json; namespace ModelContextProtocol.TestServer; @@ -91,28 +92,39 @@ private static ToolsCapability ConfigureTools() { Name = "echo", Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() - { - Type = "object", - Properties = new Dictionary() + InputSchema = JsonSerializer.Deserialize(""" { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back." + } + }, + "required": ["message"] } - }, + """), }, new Tool() { Name = "sampleLLM", Description = "Samples from an LLM using MCP's sampling feature.", - InputSchema = new JsonSchema() - { - Type = "object", - Properties = new Dictionary() + InputSchema = JsonSerializer.Deserialize(""" { - ["prompt"] = new JsonSchemaProperty() { Type = "string", Description = "The prompt to send to the LLM" }, - ["maxTokens"] = new JsonSchemaProperty() { Type = "number", Description = "Maximum number of tokens to generate" } + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to send to the LLM" + }, + "maxTokens": { + "type": "number", + "description": "Maximum number of tokens to generate" + } + }, + "required": ["prompt", "maxTokens"] } - }, + """), } ] }); diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index be06f162..ae58174f 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -4,6 +4,7 @@ using Microsoft.Extensions.Logging; using Serilog; using System.Text; +using System.Text.Json; internal class Program { @@ -121,28 +122,39 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { Name = "echo", Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() - { - Type = "object", - Properties = new Dictionary() + InputSchema = JsonSerializer.Deserialize(""" { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back." + } + }, + "required": ["message"] } - }, + """), }, new Tool() { Name = "sampleLLM", Description = "Samples from an LLM using MCP's sampling feature.", - InputSchema = new JsonSchema() - { - Type = "object", - Properties = new Dictionary() + InputSchema = JsonSerializer.Deserialize(""" { - ["prompt"] = new JsonSchemaProperty() { Type = "string", Description = "The prompt to send to the LLM" }, - ["maxTokens"] = new JsonSchemaProperty() { Type = "number", Description = "Maximum number of tokens to generate" } + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to send to the LLM" + }, + "maxTokens": { + "type": "number", + "description": "Maximum number of tokens to generate" + } + }, + "required": ["prompt", "maxTokens"] } - }, + """), } ] }); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 349c91cb..42c11e47 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -48,16 +48,10 @@ public async Task Can_List_Registered_Tool() var tool = result.Tools[0]; Assert.Equal("Echo", tool.Name); Assert.Equal("Echoes the input back to the client.", tool.Description); - Assert.NotNull(tool.InputSchema); - Assert.Equal("object", tool.InputSchema.Type); - Assert.NotNull(tool.InputSchema.Properties); - Assert.NotEmpty(tool.InputSchema.Properties); - Assert.Contains("message", tool.InputSchema.Properties); - Assert.Equal("string", tool.InputSchema.Properties["message"].Type); - Assert.Equal("the echoes message", tool.InputSchema.Properties["message"].Description); - Assert.NotNull(tool.InputSchema.Required); - Assert.NotEmpty(tool.InputSchema.Required); - Assert.Contains("message", tool.InputSchema.Required); + Assert.Equal("object", tool.InputSchema.GetProperty("type").GetString()); + Assert.Equal(JsonValueKind.Object, tool.InputSchema.GetProperty("properties").GetProperty("message").ValueKind); + Assert.Equal("the echoes message", tool.InputSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString()); + Assert.Equal(1, tool.InputSchema.GetProperty("required").GetArrayLength()); tool = result.Tools[1]; Assert.Equal("double_echo", tool.Name); @@ -288,31 +282,15 @@ public async Task Recognizes_Parameter_Types() var tool = result.Tools.First(t => t.Name == "TestTool"); Assert.Equal("TestTool", tool.Name); Assert.Empty(tool.Description!); - Assert.NotNull(tool.InputSchema); - Assert.Equal("object", tool.InputSchema.Type); - Assert.NotNull(tool.InputSchema.Properties); - Assert.NotEmpty(tool.InputSchema.Properties); - - Assert.Contains("number", tool.InputSchema.Properties); - Assert.Equal("integer", tool.InputSchema.Properties["number"].Type); - - Assert.Contains("otherNumber", tool.InputSchema.Properties); - Assert.Equal("number", tool.InputSchema.Properties["otherNumber"].Type); - - Assert.Contains("someCheck", tool.InputSchema.Properties); - Assert.Equal("boolean", tool.InputSchema.Properties["someCheck"].Type); - - Assert.Contains("someDate", tool.InputSchema.Properties); - Assert.Equal("string", tool.InputSchema.Properties["someDate"].Type); - - Assert.Contains("someOtherDate", tool.InputSchema.Properties); - Assert.Equal("string", tool.InputSchema.Properties["someOtherDate"].Type); - - Assert.Contains("data", tool.InputSchema.Properties); - Assert.Equal("array", tool.InputSchema.Properties["data"].Type); - - Assert.Contains("complexObject", tool.InputSchema.Properties); - Assert.Equal("object", tool.InputSchema.Properties["complexObject"].Type); + Assert.Equal("object", tool.InputSchema.GetProperty("type").GetString()); + + Assert.Contains("integer", tool.InputSchema.GetProperty("properties").GetProperty("number").GetProperty("type").GetString()); + Assert.Contains("number", tool.InputSchema.GetProperty("properties").GetProperty("otherNumber").GetProperty("type").GetString()); + Assert.Contains("boolean", tool.InputSchema.GetProperty("properties").GetProperty("someCheck").GetProperty("type").GetString()); + Assert.Contains("string", tool.InputSchema.GetProperty("properties").GetProperty("someDate").GetProperty("type").GetString()); + Assert.Contains("string", tool.InputSchema.GetProperty("properties").GetProperty("someOtherDate").GetProperty("type").GetString()); + Assert.Contains("array", tool.InputSchema.GetProperty("properties").GetProperty("data").GetProperty("type").GetString()); + Assert.Contains("object", tool.InputSchema.GetProperty("properties").GetProperty("complexObject").GetProperty("type").GetString()); } [McpToolType] diff --git a/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs new file mode 100644 index 00000000..25cda950 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/ProtocolTypeTests.cs @@ -0,0 +1,52 @@ +using ModelContextProtocol.Protocol.Types; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol; + +public static class ProtocolTypeTests +{ + [Fact] + public static void ToolInputSchema_HasValidDefaultSchema() + { + var tool = new Tool(); + JsonElement jsonElement = tool.InputSchema; + + Assert.Equal(JsonValueKind.Object, jsonElement.ValueKind); + Assert.Single(jsonElement.EnumerateObject()); + Assert.True(jsonElement.TryGetProperty("type", out JsonElement typeElement)); + Assert.Equal(JsonValueKind.String, typeElement.ValueKind); + Assert.Equal("object", typeElement.GetString()); + } + + [Theory] + [InlineData("null")] + [InlineData("false")] + [InlineData("true")] + [InlineData("3.5e3")] + [InlineData("[]")] + [InlineData("{}")] + [InlineData("""{"properties":{}}""")] + [InlineData("""{"type":"number"}""")] + [InlineData("""{"type":"array"}""")] + [InlineData("""{"type":["object"]}""")] + public static void ToolInputSchema_RejectsInvalidSchemaDocuments(string invalidSchema) + { + using var document = JsonDocument.Parse(invalidSchema); + var tool = new Tool(); + + Assert.Throws(() => tool.InputSchema = document.RootElement); + } + + [Theory] + [InlineData("""{"type":"object"}""")] + [InlineData("""{"type":"object", "properties": {}, "required" : [] }""")] + [InlineData("""{"type":"object", "title": "MyAwesomeTool", "description": "It's awesome!", "properties": {}, "required" : ["NotAParam"] }""")] + public static void ToolInputSchema_AcceptsValidSchemaDocuments(string validSchema) + { + using var document = JsonDocument.Parse(validSchema); + var tool = new Tool(); + + tool.InputSchema = document.RootElement; + Assert.True(JsonElement.DeepEquals(document.RootElement, tool.InputSchema)); + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 0f23c8fe..44294009 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -1,6 +1,7 @@ using System.Threading.Channels; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Tests.Utils; @@ -71,7 +72,7 @@ private async Task Sampling(JsonRpcRequest request, CancellationToken cancellati await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = new Protocol.Types.CreateMessageResult { Content = new(), Model = "model", Role = "role" } + Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" } }, cancellationToken); }