Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,27 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-anthropic</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-anthropic</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpSyncServer;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
Expand All @@ -54,6 +53,7 @@
import net.javacrumbs.jsonunit.assertj.JsonAssertions;
import net.javacrumbs.jsonunit.core.Option;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springaicommunity.mcp.annotation.McpArg;
Expand All @@ -68,21 +68,23 @@
import org.springaicommunity.mcp.annotation.McpSampling;
import org.springaicommunity.mcp.annotation.McpTool;
import org.springaicommunity.mcp.annotation.McpToolParam;
import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification;
import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification;
import org.springaicommunity.mcp.method.progress.SyncProgressSpecification;
import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification;
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

import org.springframework.ai.mcp.annotation.spring.SyncMcpAnnotationProviders;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration;
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration;
import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration;
import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration;
import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration;
import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration;
import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration;
import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration;
import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerSpecificationFactoryAutoConfiguration;
import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerProperties;
import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerStreamableHttpProperties;
import org.springframework.ai.model.anthropic.autoconfigure.AnthropicChatAutoConfiguration;
import org.springframework.ai.model.chat.client.autoconfigure.ChatClientAutoConfiguration;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
Expand All @@ -98,16 +100,22 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.map;

@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
public class StreamableMcpAnnotationsManualIT {

private final ApplicationContextRunner serverContextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.mcp.server.protocol=STREAMABLE")
.withConfiguration(AutoConfigurations.of(McpServerAutoConfiguration.class,
.withConfiguration(AutoConfigurations.of(McpServerAnnotationScannerAutoConfiguration.class,
McpServerSpecificationFactoryAutoConfiguration.class, McpServerAutoConfiguration.class,
ToolCallbackConverterAutoConfiguration.class, McpServerStreamableHttpWebFluxAutoConfiguration.class));

private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class,
McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class));
McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class,
// MCP Annotations
McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class,
// Anthropic ChatClient Builder
AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class));

@Test
void clientServerCapabilities() {
Expand Down Expand Up @@ -141,6 +149,7 @@ void clientServerCapabilities() {

this.clientApplicationContext.withUserConfiguration(TestMcpClientConfiguration.class)
.withPropertyValues(// @formatter:off
"spring.ai.anthropic.api-key=" + System.getenv("ANTHROPIC_API_KEY"),
"spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:" + serverPort,
// "spring.ai.mcp.client.request-timeout=20m",
"spring.ai.mcp.client.initialized=false") // @formatter:on
Expand Down Expand Up @@ -306,28 +315,6 @@ public McpServerHandlers serverSideSpecProviders() {
return new McpServerHandlers();
}

@Bean
public List<McpServerFeatures.SyncToolSpecification> myTools(McpServerHandlers serverSideSpecProviders) {
return SyncMcpAnnotationProviders.toolSpecifications(List.of(serverSideSpecProviders));
}

@Bean
public List<McpServerFeatures.SyncResourceSpecification> myResources(
McpServerHandlers serverSideSpecProviders) {
return SyncMcpAnnotationProviders.resourceSpecifications(List.of(serverSideSpecProviders));
}

@Bean
public List<McpServerFeatures.SyncPromptSpecification> myPrompts(McpServerHandlers serverSideSpecProviders) {
return SyncMcpAnnotationProviders.promptSpecifications(List.of(serverSideSpecProviders));
}

@Bean
public List<McpServerFeatures.SyncCompletionSpecification> myCompletions(
McpServerHandlers serverSideSpecProviders) {
return SyncMcpAnnotationProviders.completeSpecifications(List.of(serverSideSpecProviders));
}

public static class McpServerHandlers {

@McpTool(description = "Test tool", name = "tool1")
Expand Down Expand Up @@ -449,28 +436,9 @@ public TestContext testContext() {
}

@Bean
public McpClientHandlers mcpClientHandlers(TestContext testContext) {
return new McpClientHandlers(testContext);
}

@Bean
List<SyncLoggingSpecification> loggingSpecs(McpClientHandlers clientMcpHandlers) {
return SyncMcpAnnotationProviders.loggingSpecifications(List.of(clientMcpHandlers));
}

@Bean
List<SyncSamplingSpecification> samplingSpecs(McpClientHandlers clientMcpHandlers) {
return SyncMcpAnnotationProviders.samplingSpecifications(List.of(clientMcpHandlers));
}

@Bean
List<SyncElicitationSpecification> elicitationSpecs(McpClientHandlers clientMcpHandlers) {
return SyncMcpAnnotationProviders.elicitationSpecifications(List.of(clientMcpHandlers));
}

@Bean
List<SyncProgressSpecification> progressSpecs(McpClientHandlers clientMcpHandlers) {
return SyncMcpAnnotationProviders.progressSpecifications(List.of(clientMcpHandlers));
public McpClientHandlers mcpClientHandlers(TestContext testContext,
ObjectProvider<ChatClient.Builder> chatClientBuilderProvider) {
return new McpClientHandlers(testContext, chatClientBuilderProvider);
}

public static class TestContext {
Expand All @@ -489,8 +457,21 @@ public static class McpClientHandlers {

private TestContext testContext;

public McpClientHandlers(TestContext testContext) {
private final ObjectProvider<ChatClient.Builder> chatClientBuilderProvider;

private AtomicReference<ChatClient> chatClientRef = new AtomicReference<>();

private ChatClient chatClient() {
if (this.chatClientRef.get() == null) {
this.chatClientRef.compareAndSet(null, this.chatClientBuilderProvider.getIfAvailable().build());
}
return this.chatClientRef.get();
}

public McpClientHandlers(TestContext testContext,
ObjectProvider<ChatClient.Builder> chatClientBuilderProvider) {
this.testContext = testContext;
this.chatClientBuilderProvider = chatClientBuilderProvider;
}

@McpProgress(clients = "server1")
Expand All @@ -515,6 +496,11 @@ public CreateMessageResult samplingHandler(CreateMessageRequest llmRequest) {
String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text();
String modelHint = llmRequest.modelPreferences().hints().get(0).name();

// String joke =
// this.chatClientBuilderProvider.getIfAvailable().build().prompt("Tell me
// a joke").call().content();
String joke = this.chatClient().prompt("Tell me a joke").call().content();
logger.info("Received joke from chat client: {}", joke);
return CreateMessageResult.builder()
.content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint))
.build();
Expand Down