Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.modelcontextprotocol.json.TypeRef;

import io.modelcontextprotocol.json.schema.JsonSchemaValidator;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
Expand Down Expand Up @@ -75,6 +77,7 @@
* @author Dariusz Jędrzejczyk
* @author Christian Tzolov
* @author Jihoon Kim
* @author Anurag Pant
* @see McpClient
* @see McpSchema
* @see McpClientSession
Expand Down Expand Up @@ -152,16 +155,33 @@ public class McpAsyncClient {
*/
private final LifecycleInitializer initializer;

/**
* JSON schema validator to use for validating tool responses against output schemas.
*/
private final JsonSchemaValidator jsonSchemaValidator;

/**
* Cached tool output schemas.
*/
private final ConcurrentHashMap<String, Optional<Map<String, Object>>> toolsOutputSchemaCache;

/**
* Whether to enable automatic schema caching during callTool operations.
*/
private final boolean enableCallToolSchemaCaching;

/**
* Create a new McpAsyncClient with the given transport and session request-response
* timeout.
* @param transport the transport to use.
* @param requestTimeout the session request-response timeout.
* @param initializationTimeout the max timeout to await for the client-server
* @param features the MCP Client supported features.
* @param jsonSchemaValidator the JSON schema validator to use for validating tool
* @param features the MCP Client supported features. responses against output
* schemas.
*/
McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout,
McpClientFeatures.Async features) {
JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features) {

Assert.notNull(transport, "Transport must not be null");
Assert.notNull(requestTimeout, "Request timeout must not be null");
Expand All @@ -171,6 +191,9 @@ public class McpAsyncClient {
this.clientCapabilities = features.clientCapabilities();
this.transport = transport;
this.roots = new ConcurrentHashMap<>(features.roots());
this.jsonSchemaValidator = jsonSchemaValidator;
this.toolsOutputSchemaCache = new ConcurrentHashMap<>();
this.enableCallToolSchemaCaching = features.enableCallToolSchemaCaching();

// Request Handlers
Map<String, RequestHandler<?>> requestHandlers = new HashMap<>();
Expand Down Expand Up @@ -539,15 +562,61 @@ private RequestHandler<ElicitResult> elicitationCreateHandler() {
* @see #listTools()
*/
public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
return this.initializer.withIntitialization("calling tools", init -> {
if (init.initializeResult().capabilities().tools() == null) {
return Mono.error(new IllegalStateException("Server does not provide tools capability"));
}
return init.mcpSession()
.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
return Mono.defer(() -> {
// Conditionally cache schemas if needed, otherwise return empty Mono
Mono<Void> cachingStep = (this.enableCallToolSchemaCaching
&& !this.toolsOutputSchemaCache.containsKey(callToolRequest.name())) ? this.listTools().then()
: Mono.empty();

return cachingStep.then(this.initializer.withIntitialization("calling tool", init -> {
if (init.initializeResult().capabilities().tools() == null) {
return Mono.error(new IllegalStateException("Server does not provide tools capability"));
}

return init.mcpSession()
.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF)
.flatMap(result -> validateToolResult(callToolRequest.name(), result));
}));
});
}

/**
* Calls a tool provided by the server and validates the result against the cached
* output schema.
* @param toolName The name of the tool to call
* @param result The result of the tool call
* @return A Mono that emits the validated tool result
*/
private Mono<McpSchema.CallToolResult> validateToolResult(String toolName, McpSchema.CallToolResult result) {
Optional<Map<String, Object>> optOutputSchema = toolsOutputSchemaCache.get(toolName);

if (result != null && result.isError() != null && !result.isError()) {
if (optOutputSchema == null) {
// Tool not found in cache - skip validation and proceed
logger.debug("Tool '{}' not found in cache, skipping validation", toolName);
return Mono.just(result);
}
else {
if (optOutputSchema.isPresent()) {
// Validate the tool output against the cached output schema
var validation = this.jsonSchemaValidator.validate(optOutputSchema.get(),
result.structuredContent());
if (!validation.valid()) {
logger.warn("Tool call result validation failed: {}", validation.errorMessage());
return Mono.just(new McpSchema.CallToolResult(validation.errorMessage(), true));
}
}
else if (result.structuredContent() != null) {
logger.warn(
"Calling a tool with no outputSchema is not expected to return result with structured content, but got: {}",
result.structuredContent());
}
}
}

return Mono.just(result);
}

/**
* Retrieves the list of all tools provided by the server.
* @return A Mono that emits the list of all tools result
Expand All @@ -574,7 +643,16 @@ public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
}
return init.mcpSession()
.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_TOOLS_RESULT_TYPE_REF);
LIST_TOOLS_RESULT_TYPE_REF)
.map(result -> {
if (result.tools() != null) {
// Cache tools output schema
result.tools()
.forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(),
Optional.ofNullable(tool.outputSchema())));
}
return result;
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.function.Function;
import java.util.function.Supplier;

import io.modelcontextprotocol.json.schema.JsonSchemaValidator;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
Expand Down Expand Up @@ -99,6 +100,7 @@
*
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
* @author Anurag Pant
* @see McpAsyncClient
* @see McpSyncClient
* @see McpTransport
Expand Down Expand Up @@ -187,6 +189,10 @@ class SyncSpec {

private Supplier<McpTransportContext> contextProvider = () -> McpTransportContext.EMPTY;

private JsonSchemaValidator jsonSchemaValidator;

private boolean enableCallToolSchemaCaching = false; // Default to false

private SyncSpec(McpClientTransport transport) {
Assert.notNull(transport, "Transport must not be null");
this.transport = transport;
Expand Down Expand Up @@ -429,6 +435,32 @@ public SyncSpec transportContextProvider(Supplier<McpTransportContext> contextPr
return this;
}

/**
* Add a {@link JsonSchemaValidator} to validate the JSON structure of the
* structured output.
* @param jsonSchemaValidator A validator to validate the JSON structure of the
* structured output. Must not be null.
* @return This builder for method chaining
* @throws IllegalArgumentException if jsonSchemaValidator is null
*/
public SyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) {
Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null");
this.jsonSchemaValidator = jsonSchemaValidator;
return this;
}

/**
* Enables automatic schema caching during callTool operations. When a tool's
* output schema is not found in the cache, callTool will automatically fetch and
* cache all tool schemas via listTools.
* @param enableCallToolSchemaCaching true to enable, false to disable
* @return This builder instance for method chaining
*/
public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) {
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
return this;
}

/**
* Create an instance of {@link McpSyncClient} with the provided configurations or
* sensible defaults.
Expand All @@ -438,13 +470,13 @@ public McpSyncClient build() {
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
this.elicitationHandler);
this.elicitationHandler, this.enableCallToolSchemaCaching);

McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);

return new McpSyncClient(
new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures),
this.contextProvider);
return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout,
jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault(),
asyncFeatures), this.contextProvider);
}

}
Expand Down Expand Up @@ -495,6 +527,10 @@ class AsyncSpec {

private Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler;

private JsonSchemaValidator jsonSchemaValidator;

private boolean enableCallToolSchemaCaching = false; // Default to false

private AsyncSpec(McpClientTransport transport) {
Assert.notNull(transport, "Transport must not be null");
this.transport = transport;
Expand Down Expand Up @@ -741,17 +777,45 @@ public AsyncSpec progressConsumers(
return this;
}

/**
* Sets the JSON schema validator to use for validating tool responses against
* output schemas.
* @param jsonSchemaValidator The validator to use. Must not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if jsonSchemaValidator is null
*/
public AsyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) {
Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null");
this.jsonSchemaValidator = jsonSchemaValidator;
return this;
}

/**
* Enables automatic schema caching during callTool operations. When a tool's
* output schema is not found in the cache, callTool will automatically fetch and
* cache all tool schemas via listTools.
* @param enableCallToolSchemaCaching true to enable, false to disable
* @return This builder instance for method chaining
*/
public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) {
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
return this;
}

/**
* Create an instance of {@link McpAsyncClient} with the provided configurations
* or sensible defaults.
* @return a new instance of {@link McpAsyncClient}.
*/
public McpAsyncClient build() {
var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator
: JsonSchemaValidator.getDefault();
return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
jsonSchemaValidator,
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
this.samplingHandler, this.elicitationHandler));
this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching));
}

}
Expand Down
Loading
Loading