Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions internal/server/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3010,3 +3010,173 @@ func TestE2E_ServerDeleteReaddDifferentTools(t *testing.T) {
t.Log("Phase 3 & 4 Complete: ONLY Tool Set B (new_tool_gamma) searchable and callable")
t.Log("SUCCESS: Stale index entries cleaned up correctly on server re-add")
}

// Test: retrieve_tools returns correct annotations and call_with based on tool hints (Issue #306)
func TestE2E_RetrieveToolsAnnotationsAndCallWith(t *testing.T) {
env := NewTestEnvironment(t)
defer env.Cleanup()

trueVal := true
falseVal := false

// Create mock upstream server with tools that have different annotation hints
mockTools := []mcp.Tool{
{
Name: "delete_records",
Description: "Delete records from the database permanently",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{},
},
Annotations: mcp.ToolAnnotation{
DestructiveHint: &trueVal,
ReadOnlyHint: &falseVal,
},
},
{
Name: "update_config",
Description: "Update server configuration settings",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{},
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: &falseVal,
},
},
{
Name: "list_items",
Description: "List all items in the inventory",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{},
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: &trueVal,
},
},
}

mockServer := env.CreateMockUpstreamServer("annotated", mockTools)

mcpClient := env.CreateProxyClient()
defer mcpClient.Close()
env.ConnectClient(mcpClient)

ctx := context.Background()

// Add and unquarantine the server
addRequest := mcp.CallToolRequest{}
addRequest.Params.Name = "upstream_servers"
addRequest.Params.Arguments = map[string]interface{}{
"operation": "add",
"name": "annotated",
"url": mockServer.addr,
"protocol": "streamable-http",
"enabled": true,
}

result, err := mcpClient.CallTool(ctx, addRequest)
require.NoError(t, err)
assert.False(t, result.IsError)

serverConfig, err := env.proxyServer.runtime.StorageManager().GetUpstreamServer("annotated")
require.NoError(t, err)
serverConfig.Quarantined = false
err = env.proxyServer.runtime.StorageManager().SaveUpstreamServer(serverConfig)
require.NoError(t, err)

servers, err := env.proxyServer.runtime.StorageManager().ListUpstreamServers()
require.NoError(t, err)

cfg := env.proxyServer.runtime.Config()
cfg.Servers = servers
err = env.proxyServer.runtime.LoadConfiguredServers(cfg)
require.NoError(t, err)

// Wait for connection and trigger tool discovery
time.Sleep(4 * time.Second)
err = env.proxyServer.runtime.DiscoverAndIndexTools(ctx)
require.NoError(t, err)
time.Sleep(2 * time.Second)

// Do a broad search to find all annotated tools
searchRequest := mcp.CallToolRequest{}
searchRequest.Params.Name = "retrieve_tools"
searchRequest.Params.Arguments = map[string]interface{}{
"query": "delete update list records config items",
"limit": 20,
}
searchResult, err := mcpClient.CallTool(ctx, searchRequest)
require.NoError(t, err)
assert.False(t, searchResult.IsError)

require.Greater(t, len(searchResult.Content), 0)
contentBytes, err := json.Marshal(searchResult.Content[0])
require.NoError(t, err)
var contentMap map[string]interface{}
err = json.Unmarshal(contentBytes, &contentMap)
require.NoError(t, err)
contentText, ok := contentMap["text"].(string)
require.True(t, ok)

var searchResponse map[string]interface{}
err = json.Unmarshal([]byte(contentText), &searchResponse)
require.NoError(t, err)

tools, ok := searchResponse["tools"].([]interface{})
require.True(t, ok, "tools should be an array")
t.Logf("Found %d tools in search", len(tools))
for _, toolRaw := range tools {
if tool, ok := toolRaw.(map[string]interface{}); ok {
t.Logf(" Tool: name=%v, server=%v, call_with=%v, annotations=%v",
tool["name"], tool["server"], tool["call_with"], tool["annotations"])
}
}
require.GreaterOrEqual(t, len(tools), 3, "Should find all 3 annotated tools")

// Verify call_with for each tool
// Note: tool names may be bare ("delete_records") or prefixed ("annotated:delete_records")
// depending on how they were indexed. Match by bare name + server field.
expectedCallWith := map[string]string{
"delete_records": "call_tool_destructive",
"update_config": "call_tool_write",
"list_items": "call_tool_read",
}

for _, toolRaw := range tools {
tool, ok := toolRaw.(map[string]interface{})
if !ok {
continue
}
toolName, _ := tool["name"].(string)
serverField, _ := tool["server"].(string)

// Strip server prefix if present for matching
bareName := toolName
if parts := strings.SplitN(toolName, ":", 2); len(parts) == 2 {
bareName = parts[1]
}

expected, isOurs := expectedCallWith[bareName]
if !isOurs || serverField != "annotated" {
continue
}

callWith, ok := tool["call_with"].(string)
assert.True(t, ok, "call_with should be a string for %s", toolName)
assert.Equal(t, expected, callWith,
"call_with mismatch for %s: expected %s, got %s",
toolName, expected, callWith)

// Verify annotations are present for destructive/write tools
if expected != "call_tool_read" {
assert.NotNil(t, tool["annotations"],
"annotations should be present for %s", toolName)
}

delete(expectedCallWith, bareName)
}

assert.Empty(t, expectedCallWith, "Not all expected tools were found: %v", expectedCallWith)
}
23 changes: 17 additions & 6 deletions internal/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,17 +847,26 @@ func (p *MCPProxyServer) handleRetrieveTools(ctx context.Context, request mcp.Ca
}

// Look up tool annotations and derive recommended call_with variant (Spec 018)
// Parse tool name to get just the tool part (format: server:tool)
parts := strings.SplitN(result.Tool.Name, ":", 2)
if len(parts) == 2 {
annotations := p.lookupToolAnnotations(parts[0], parts[1])
// Use ServerName directly - result.Tool.Name may or may not have "server:" prefix
// depending on how tools were indexed (Issue #306)
serverName := result.Tool.ServerName
toolName := result.Tool.Name
if serverName == "" {
// Fallback: try to extract from "server:tool" format
if parts := strings.SplitN(result.Tool.Name, ":", 2); len(parts) == 2 {
serverName = parts[0]
toolName = parts[1]
}
}

if serverName != "" {
annotations := p.lookupToolAnnotations(serverName, toolName)
if annotations != nil {
mcpTool["annotations"] = annotations
}
// Add call_with recommendation based on annotations
mcpTool["call_with"] = contracts.DeriveCallWith(annotations)
} else {
// Fallback for tools without server prefix (shouldn't happen normally)
mcpTool["call_with"] = contracts.ToolVariantRead // Default to read - safest option
}

Expand Down Expand Up @@ -3984,7 +3993,9 @@ func (p *MCPProxyServer) lookupToolAnnotations(serverName, toolName string) *con
}

for _, tool := range serverStatus.Tools {
if tool.Name == toolName {
// tool.Name may be in "server:tool" format (from ToolMetadata.Name),
// while toolName is just the tool part. Match both formats.
if tool.Name == toolName || tool.Name == serverName+":"+toolName {
return tool.Annotations
}
}
Expand Down
158 changes: 158 additions & 0 deletions internal/server/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go.uber.org/zap"

"github.com/smart-mcp-proxy/mcpproxy-go/internal/config"
"github.com/smart-mcp-proxy/mcpproxy-go/internal/contracts"
"github.com/smart-mcp-proxy/mcpproxy-go/internal/secret"
"github.com/smart-mcp-proxy/mcpproxy-go/internal/upstream"
)
Expand Down Expand Up @@ -307,6 +308,163 @@ func TestToolFormatConversion(t *testing.T) {
assert.Contains(t, properties, "market_data")
}

// TestAnnotationLookupNameMatching tests that lookupToolAnnotations correctly
// matches tool names regardless of whether StateView stores them as "tool"
// or "server:tool" format. This is the bug reported in Issue #306.
func TestAnnotationLookupNameMatching(t *testing.T) {
trueVal := true
falseVal := false

tests := []struct {
name string
serverName string
toolName string
stateViewName string // How the tool name is stored in StateView
annotations *config.ToolAnnotations
expectedCallWith string
}{
{
name: "StateView stores full name (server:tool), lookup uses bare tool name",
serverName: "github",
toolName: "delete_repo",
stateViewName: "github:delete_repo",
annotations: &config.ToolAnnotations{
DestructiveHint: &trueVal,
},
expectedCallWith: "call_tool_destructive",
},
{
name: "StateView stores bare tool name, lookup uses bare tool name",
serverName: "github",
toolName: "delete_repo",
stateViewName: "delete_repo",
annotations: &config.ToolAnnotations{
DestructiveHint: &trueVal,
},
expectedCallWith: "call_tool_destructive",
},
{
name: "write tool with full name in StateView",
serverName: "myserver",
toolName: "update_config",
stateViewName: "myserver:update_config",
annotations: &config.ToolAnnotations{
ReadOnlyHint: &falseVal,
},
expectedCallWith: "call_tool_write",
},
{
name: "read-only tool with full name in StateView",
serverName: "myserver",
toolName: "list_items",
stateViewName: "myserver:list_items",
annotations: &config.ToolAnnotations{
ReadOnlyHint: &trueVal,
},
expectedCallWith: "call_tool_read",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the name matching logic from lookupToolAnnotations
// This tests the fix for Issue #306 where tool.Name in StateView
// is "server:tool" but toolName passed in is just "tool"
matched := false
toolNameInStateView := tt.stateViewName
if toolNameInStateView == tt.toolName || toolNameInStateView == tt.serverName+":"+tt.toolName {
matched = true
}

assert.True(t, matched, "Tool name matching failed: stateview=%q, lookup=%q",
tt.stateViewName, tt.toolName)

// Verify DeriveCallWith returns correct variant when annotations are found
if matched {
callWith := contracts.DeriveCallWith(tt.annotations)
assert.Equal(t, tt.expectedCallWith, callWith)
}
})
}
}

// TestRetrieveToolsCallWithAnnotations verifies that the handleRetrieveTools
// code path correctly splits tool names and derives call_with from annotations.
// This is a regression test for Issue #306.
func TestRetrieveToolsCallWithAnnotations(t *testing.T) {
trueVal := true
falseVal := false

// Simulate search results as returned by the index
mockResults := []*config.SearchResult{
{
Tool: &config.ToolMetadata{
Name: "myserver:delete_data",
ServerName: "myserver",
Annotations: &config.ToolAnnotations{
DestructiveHint: &trueVal,
},
},
Score: 0.9,
},
{
Tool: &config.ToolMetadata{
Name: "myserver:update_config",
ServerName: "myserver",
Annotations: &config.ToolAnnotations{
ReadOnlyHint: &falseVal,
},
},
Score: 0.8,
},
{
Tool: &config.ToolMetadata{
Name: "myserver:list_items",
ServerName: "myserver",
Annotations: &config.ToolAnnotations{
ReadOnlyHint: &trueVal,
},
},
Score: 0.7,
},
{
Tool: &config.ToolMetadata{
Name: "myserver:unknown_tool",
ServerName: "myserver",
// No annotations
},
Score: 0.6,
},
}

// Simulate the annotation lookup + call_with derivation from handleRetrieveTools
// In production, lookupToolAnnotations queries the StateView, but we can test
// the name-splitting logic and DeriveCallWith here.
for _, result := range mockResults {
parts := strings.SplitN(result.Tool.Name, ":", 2)
require.Len(t, parts, 2, "Tool name should be in server:tool format: %s", result.Tool.Name)

// The fix for #306: even if lookupToolAnnotations can't find annotations
// via StateView (because of name mismatch), we can verify the split is correct
assert.Equal(t, result.Tool.ServerName, parts[0],
"Server name from split should match ServerName field")

// Verify DeriveCallWith with the tool's own annotations
callWith := contracts.DeriveCallWith(result.Tool.Annotations)

switch result.Tool.Name {
case "myserver:delete_data":
assert.Equal(t, "call_tool_destructive", callWith)
case "myserver:update_config":
assert.Equal(t, "call_tool_write", callWith)
case "myserver:list_items":
assert.Equal(t, "call_tool_read", callWith)
case "myserver:unknown_tool":
assert.Equal(t, "call_tool_read", callWith) // nil annotations → safe default
}
}
}

func TestUpstreamServerOperations(t *testing.T) {
// Test basic server operations parsing
t.Run("BasicServerOperations", func(t *testing.T) {
Expand Down