Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>cn.bigmodel.openapi</groupId>
<artifactId>oapi-java-sdk</artifactId>
<version>release-V4-2.4.4</version>
<version>release-V4-2.4.5</version>

<packaging>jar</packaging>

Expand Down
7 changes: 7 additions & 0 deletions src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public class ChatTool extends ObjectNode {
@JsonProperty("web_search")
private WebSearch web_search;

private MCPTool mcp;

public ChatTool(){
super(JsonNodeFactory.instance);
}
Expand All @@ -48,4 +50,9 @@ public void setWeb_search(WebSearch web_search) {
this.web_search = web_search;
this.putPOJO("web_search",web_search);
}

public void setMcp(MCPTool mcp) {
this.mcp = mcp;
this.putPOJO("mcp", mcp);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ public enum ChatToolType {

RETRIEVAL("retrieval"),

FUNCTION("function");
FUNCTION("function"),

MCP("mcp"),
;

private final String value;

Expand Down
75 changes: 75 additions & 0 deletions src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.zhipu.oapi.service.v4.model;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.Map;
import java.util.Set;

@Getter
public class MCPTool extends ObjectNode {
/**
* mcp server 的标识,用于区分不同的 mcp server,必填
*/
private String server_label;

/**
* mcp server 的 url,非必填
* 默认(若该字段为空):以 server_label 作为 mcpCode,连接智谱AI的 mcp servers,
*/
private String server_url;

/**
* mcp 调用的传输方式:sse/streamable-http,默认为 streamable-http
*/
private String transport_type;

/**
* 允许调用的工具列表,默认为空,即允许所有工具
*/
private Set<String> allowed_tools;

/**
* 连接 mcp server 的 headers,鉴权使用
*/
private Map<String, String> headers;

public MCPTool() {
super(JsonNodeFactory.instance);
}

public MCPTool(JsonNodeFactory nc, Map<String, JsonNode> kids) {
super(nc, kids);
}

public void setServer_label(String server_label) {
this.server_label = server_label;
this.put("server_label", server_label);
}

public void setServer_url(String server_url) {
this.server_url = server_url;
this.put("server_url", server_url);
}

public void setTransport_type(String transport_type) {
this.transport_type = transport_type;
this.put("transport_type", transport_type);
}

public void setAllowed_tools(Set<String> allowed_tools) {
this.allowed_tools = allowed_tools;
this.putPOJO("allowed_tools", allowed_tools);
}

public void setHeaders(Map<String, String> headers) {
this.headers = headers;
this.putPOJO("headers", headers);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.zhipu.oapi.service.v4.model;

import lombok.AllArgsConstructor;
import lombok.Getter;

@AllArgsConstructor
@Getter
public enum McpToolTransportType {

SSE("sse", "SSE"),
STREAMABLE_HTTP("streamable-http", "可流式传输的HTTP");

private final String code;
private final String value;

}
156 changes: 156 additions & 0 deletions src/test/java/com/zhipu/oapi/McpTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package com.zhipu.oapi;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhipu.oapi.service.v4.model.*;
import com.zhipu.oapi.utils.StringUtils;
import io.reactivex.Flowable;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

@Testcontainers
public class McpTest {

private final static Logger logger = LoggerFactory.getLogger(TestAssistantClientApiService.class);
private static final String ZHIPUAI_API_KEY = getTestApiKey();

private static ClientV4 client = null;

private static final String requestIdTemplate = "mycompany-%d";

private static final ObjectMapper mapper = new ObjectMapper();

static {
client = new ClientV4.Builder(ZHIPUAI_API_KEY)
.enableTokenCache()
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
}

private static String getTestApiKey() {
String apiKey = Constants.getApiKey();
return apiKey != null ? apiKey : "test-api-key.test-api-secret";
}

@Test
void testMcpTool_ServerUrl_SSE() throws JsonProcessingException {
// MCP 参数构建部分
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", "Bearer" + ZHIPUAI_API_KEY);
MCPTool mcpTool = new MCPTool();
mcpTool.setServer_label("sougou_search");
mcpTool.setServer_url("https://open.bigmodel.cn/api/mcp/sogou/sse");
mcpTool.setTransport_type(McpToolTransportType.SSE.getCode());
mcpTool.setHeaders(headers);

ChatTool chatTool = new ChatTool();
chatTool.setType(ChatToolType.MCP.value());
chatTool.setMcp(mcpTool);

List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "今天是几月几号?");

messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.FALSE)
.messages(messages)
.requestId(requestId)
.invokeMethod(Constants.invokeMethod)
.tools(Collections.singletonList(chatTool))
.build();
ModelApiResponse modelApiResp = client.invokeModelApi(chatCompletionRequest);
logger.info("model output: {}", mapper.writeValueAsString(modelApiResp));
}

@Test
void testMcpTool_ServerLabel() throws JsonProcessingException {
// MCP 参数构建部分
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", "Bearer " + ZHIPUAI_API_KEY);
MCPTool mcpTool = new MCPTool();
mcpTool.setServer_label("aviation");
mcpTool.setHeaders(headers);

ChatTool chatTool = new ChatTool();
chatTool.setType(ChatToolType.MCP.value());
chatTool.setMcp(mcpTool);

List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "北京现在天气怎么样?");

messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.TRUE)
.messages(messages)
.requestId(requestId)
.invokeMethod(Constants.invokeMethod)
.tools(Collections.singletonList(chatTool))
.build();
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
if (sseModelApiResp.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<Choice> choices = new ArrayList<>();
AtomicReference<ChatMessageAccumulator> lastAccumulator = new AtomicReference<>();

mapStreamToAccumulator(sseModelApiResp.getFlowable())
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
logger.info("tool_calls: {}", jsonString);
}
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
logger.info(accumulator.getDelta().getContent());
}
choices.add(accumulator.getChoice());
lastAccumulator.set(accumulator);

}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()

ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
ModelData data = new ModelData();
data.setChoices(choices);
if (chatMessageAccumulator != null) {
data.setUsage(chatMessageAccumulator.getUsage());
data.setId(chatMessageAccumulator.getId());
data.setCreated(chatMessageAccumulator.getCreated());
}
data.setRequestId(chatCompletionRequest.getRequestId());
sseModelApiResp.setFlowable(null);// 打印前置空
sseModelApiResp.setData(data);
}
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();

client.getConfig().getHttpClient().connectionPool().evictAll();
// List all active threads
for (Thread t : Thread.getAllStackTraces().keySet()) {
logger.info("Thread: " + t.getName() + " State: " + t.getState());
}
}

public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
return flowable.map(chunk -> {
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
});
}
}
Loading