diff --git a/docs/semantic-tool-filtering/v0.1/docs/semantic-tool-filtering.md b/docs/semantic-tool-filtering/v0.1/docs/semantic-tool-filtering.md new file mode 100644 index 00000000..70552d3d --- /dev/null +++ b/docs/semantic-tool-filtering/v0.1/docs/semantic-tool-filtering.md @@ -0,0 +1,178 @@ +--- +title: "Overview" +--- + +# Semantic Tool Filtering policy + +## Overview +The **Semantic Tool Filtering** policy dynamically filters the tools provided within an API request based on their semantic relevance to the user query. This policy extracts both the query and the tool definitions from the incoming payload, generates embeddings for the query, and performs a similarity search against the provided tools. It then replaces the original tools array with a filtered subset, optimizing the request before it reaches the LLM. + +This policy helps reduce token consumption and improve LLM response quality by sending only the most relevant tools for each request. + +## Features +- **Semantic similarity-based filtering** of tools using embedding vectors. +- **Two selection modes**: "By Rank" (top-K) and "By Threshold". +- **Flexible Format Support**: Supports both JSON and text-based tool/query extraction. +- **Embedding cache** with LRU eviction to minimize redundant API calls. +- **Configurable JSONPath expressions** for payload extraction. +- **Mixed mode support** (extract query from JSON and tools from text, or vice versa). + +## Configuration + +### User Parameters (API Definition) + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| selectionMode | string | Yes | `By Rank` | Method used to filter tools: `By Rank` (selects top-K) or `By Threshold` (selects based on score). | +| limit | integer | No | `5` | The number of most relevant tools to include (used if selectionMode is `By Rank`). | +| threshold | number | No | `0.7` | Similarity threshold for filtering (0.0 to 1.0). 
Only tools with a score at or above this value are included (used if selectionMode is `By Threshold`). | +| queryJSONPath | string | No | `$.messages[-1].content` | JSONPath expression to extract the user's query from the request body. | +| toolsJSONPath | string | No | `$.tools` | JSONPath expression to extract the tools array from the request body (used when `toolsIsJson` is true). | +| userQueryIsJson | boolean | No | `true` | Specifies the format of the user query. `true`: use `queryJSONPath`. `false`: extract from text using `` tags. | +| toolsIsJson | boolean | No | `true` | Specifies the format of the tool definitions. `true`: use `toolsJSONPath`. `false`: extract from text using ``/`` tags. | + +### System Parameters (From config.toml) + +These parameters are configured in the gateway's `config.toml` to set up the embedding provider. + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| embedding_provider | string | Yes | - | Embedding provider: `OPENAI`, `MISTRAL`, or `AZURE_OPENAI`. | +| embedding_provider_endpoint | string | Yes | - | Endpoint URL for the embedding service. | +| embedding_provider_model | string | Conditional | - | Model name (e.g., `text-embedding-3-small` or `mistral-embed`). Required for `OPENAI` and `MISTRAL`; optional for `AZURE_OPENAI` (the deployment name is derived from the endpoint). | +| embedding_provider_dimension | integer | Yes | - | Dimension of the embedding vectors produced by the model (set to `1024` for `mistral-embed` in the sample configuration below). | +| embedding_provider_api_key | string | Yes | - | API key for the embedding service. 
| + +#### Sample System Configuration + +Add the following configuration section under the root level in your `config.toml` file: + +```toml +embedding_provider = "MISTRAL" # Supported: MISTRAL, OPENAI, AZURE_OPENAI +embedding_provider_endpoint = "https://api.mistral.ai/v1/embeddings" +embedding_provider_model = "mistral-embed" +embedding_provider_dimension = 1024 +embedding_provider_api_key = "" +``` + +### build.yaml + +Add the following entry to the `policies` section in `/gateway/build.yaml`: + +```yaml +- name: semantic-tool-filtering + gomodule: github.com/wso2/gateway-controllers/policies/semantic-tool-filtering@v0 +``` + +## Reference Scenarios + +### Scenario 1: Filtering Tools by Rank (JSON Format) + +This scenario demonstrates filtering tools to select the top 3 most relevant ones based on a user query in a JSON payload. + +**Configuration:** + +```yaml +type: http +policies: + - policy: + name: semantic-tool-filtering + parameters: + selectionMode: "By Rank" + limit: 3 + queryJSONPath: "$.contents[0].parts[0].text" + toolsJSONPath: "$.tools[0].function_declarations" + userQueryIsJson: true + toolsIsJson: true +``` + +**Request:** + +```json +{ + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Get weather forecast. what are the tools you have?" 
+ } + ] + } + ], + "tools": [ + { + "function_declarations": [ + { + "name": "get_weather", + "description": "Get current weather and 7-day forecast for a location.", + "parameters": { "type": "OBJECT", "properties": { "location": { "type": "string" } }, "required": ["location"] } + }, + { + "name": "book_venue", + "description": "Reserve a conference room or meeting space.", + "parameters": { "type": "OBJECT", "properties": { "location": { "type": "string" } }, "required": ["location"] } + }, + { + "name": "send_email", + "description": "Send an email to a specific recipient.", + "parameters": { "type": "OBJECT", "properties": { "recipient": { "type": "string" } }, "required": ["recipient"] } + } + ] + } + ] +} +``` + +The policy parses the request, computes an embedding for the query and for each tool description, and replaces the `tools` array with the 3 highest-scoring tools. Because this example request contains exactly 3 tools, all of them (`get_weather`, `book_venue`, `send_email`) are retained; with a larger tool set, only the 3 most relevant tools would remain. + +### Scenario 2: Filtering Tools by Threshold + +In this scenario, only tools with a semantic similarity score of 0.7 or higher are included. + +**Configuration:** + +```yaml +type: http +policies: + - policy: + name: semantic-tool-filtering + parameters: + selectionMode: "By Threshold" + threshold: 0.7 +``` + +### Scenario 3: Text Format (XML-like Tags) + +This scenario handles cases where the user query and tool definitions are embedded in a text payload using custom tags. + +**Configuration:** + +```yaml +type: http +policies: + - policy: + name: semantic-tool-filtering + parameters: + selectionMode: "By Rank" + limit: 3 + userQueryIsJson: false + toolsIsJson: false +``` + +**Request Body:** + +```json +{ + "contents": [ + { + "parts": [ + { + "text": "You are a logistics assistant with access to the following tools:\n\nget_weatherGet current weather and 7-day forecast for a location\nbook_venueReserve meeting spaces\n\nI'm planning a corporate retreat in Denver next weekend. Check the weather."
+ } + ] + } + ] +} +``` + +The policy extracts `` as the query and ``/`` as tools, then performs filtering. After the filtering process, the tags are removed. diff --git a/docs/semantic-tool-filtering/v0.1/metadata.json b/docs/semantic-tool-filtering/v0.1/metadata.json new file mode 100644 index 00000000..7502ae46 --- /dev/null +++ b/docs/semantic-tool-filtering/v0.1/metadata.json @@ -0,0 +1,11 @@ +{ + "name": "semantic-tool-filtering", + "displayName": "Semantic Tool Filtering", + "version": "0.1", + "provider": "WSO2", + "categories": [ + "Guardrails", + "AI" + ], + "description": "Dynamically filters the tools provided within an API request based on their semantic relevance to the user query. This policy extracts both the query and the tool definitions from the incoming payload, generates embeddings for the query, and performs a similarity search against the provided tools. It then replaces the original tools array with a filtered subset, optimizing the request before it reaches the LLM." +} diff --git a/policies/semantic-tool-filtering/embeddingcache.go b/policies/semantic-tool-filtering/embeddingcache.go new file mode 100644 index 00000000..c5ec9412 --- /dev/null +++ b/policies/semantic-tool-filtering/embeddingcache.go @@ -0,0 +1,455 @@ +/* + * Copyright (c) 2026, WSO2 LLC. (http://www.wso2.org) All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package semantictoolfiltering + +import ( + "crypto/sha256" + "encoding/hex" + "log/slog" + "sync" + "time" +) + +// Cache limits configuration +const ( + DefaultMaxAPIs = 25 // Maximum number of APIs to store in cache + DefaultMaxToolsPerAPI = 200 // Maximum number of tools per API +) + +// EmbeddingEntry stores the tool name, its embedding vector, and last access time +type EmbeddingEntry struct { + Name string + Embedding []float32 + LastAccessed time.Time +} + +// APICache wraps the tool cache with API-level metadata +type APICache struct { + Tools map[string]*EmbeddingEntry // Key: SHA-256 hash of tool description + LastAccessed time.Time +} + +// APIEmbeddingCache stores embeddings for a specific API +// Key: SHA-256 hash of tool description → Value: EmbeddingEntry +type APIEmbeddingCache map[string]*EmbeddingEntry + +// EmbeddingCacheStore is a global singleton for storing embeddings per API +type EmbeddingCacheStore struct { + mu sync.RWMutex + cache map[string]*APICache // Key: API ID → Value: APICache + maxAPIs int + maxToolsPerAPI int +} + +// Singleton instance for EmbeddingCacheStore +var ( + embeddingCacheInstance *EmbeddingCacheStore + embeddingCacheOnce sync.Once +) + +// GetEmbeddingCacheStoreInstance returns the global singleton instance +func GetEmbeddingCacheStoreInstance() *EmbeddingCacheStore { + embeddingCacheOnce.Do(func() { + embeddingCacheInstance = &EmbeddingCacheStore{ + cache: make(map[string]*APICache), + maxAPIs: DefaultMaxAPIs, + maxToolsPerAPI: DefaultMaxToolsPerAPI, + } + }) + return embeddingCacheInstance +} + +// SetCacheLimits updates the cache limits for APIs and tools per API +func (ecs *EmbeddingCacheStore) SetCacheLimits(maxAPIs, maxToolsPerAPI int) { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + if maxAPIs > 0 { + ecs.maxAPIs = maxAPIs + } + if maxToolsPerAPI > 0 { + ecs.maxToolsPerAPI = maxToolsPerAPI + } +} + +// GetCacheLimits returns the current cache limits +func (ecs *EmbeddingCacheStore) GetCacheLimits() 
(maxAPIs, maxToolsPerAPI int) { + ecs.mu.RLock() + defer ecs.mu.RUnlock() + return ecs.maxAPIs, ecs.maxToolsPerAPI +} + +// findLRUAPI finds the least recently used API ID (must be called with lock held) +func (ecs *EmbeddingCacheStore) findLRUAPI() string { + var lruAPIId string + var oldestTime time.Time + first := true + + for apiId, apiCache := range ecs.cache { + if first || apiCache.LastAccessed.Before(oldestTime) { + oldestTime = apiCache.LastAccessed + lruAPIId = apiId + first = false + } + } + return lruAPIId +} + +// findLRUTool finds the least recently used tool hash key in an API cache (must be called with lock held) +func (ecs *EmbeddingCacheStore) findLRUTool(apiCache *APICache) string { + var lruHashKey string + var oldestTime time.Time + first := true + + for hashKey, entry := range apiCache.Tools { + if first || entry.LastAccessed.Before(oldestTime) { + oldestTime = entry.LastAccessed + lruHashKey = hashKey + first = false + } + } + return lruHashKey +} + +// evictLRUAPIIfNeeded removes the LRU API if cache is at capacity (must be called with lock held) +func (ecs *EmbeddingCacheStore) evictLRUAPIIfNeeded() { + if len(ecs.cache) >= ecs.maxAPIs { + lruAPIId := ecs.findLRUAPI() + if lruAPIId != "" { + slog.Debug("Evicting LRU API", "apiId", lruAPIId, "currentSize", len(ecs.cache), "maxSize", ecs.maxAPIs) + delete(ecs.cache, lruAPIId) + slog.Debug("LRU API evicted", "evictedApiId", lruAPIId, "newSize", len(ecs.cache)) + } else { + slog.Debug("No LRU API found to evict", "currentSize", len(ecs.cache), "maxSize", ecs.maxAPIs) + } + } +} + +// evictLRUToolIfNeeded removes the LRU tool from an API cache if at capacity (must be called with lock held) +func (ecs *EmbeddingCacheStore) evictLRUToolIfNeeded(apiCache *APICache) { + if len(apiCache.Tools) >= ecs.maxToolsPerAPI { + lruHashKey := ecs.findLRUTool(apiCache) + if lruHashKey != "" { + toolName := "" + if entry, exists := apiCache.Tools[lruHashKey]; exists { + toolName = entry.Name + } + 
slog.Debug("Evicting LRU tool", "toolName", toolName, "hash", safeHashPrefix(lruHashKey), "currentSize", len(apiCache.Tools), "maxSize", ecs.maxToolsPerAPI) + delete(apiCache.Tools, lruHashKey) + } + } +} + +// safeHashPrefix returns a safe prefix of the hash string for logging +// Returns the first 16 characters if available, otherwise the full string or "" +func safeHashPrefix(hash string) string { + if len(hash) == 0 { + return "" + } + if len(hash) < 16 { + return hash + } + return hash[:16] +} + +// HashDescription computes SHA-256 hash of the tool description +func HashDescription(description string) string { + hash := sha256.Sum256([]byte(description)) + return hex.EncodeToString(hash[:]) +} + +// HasAPI checks if there is a cache entry for the given API ID +func (ecs *EmbeddingCacheStore) HasAPI(apiId string) bool { + ecs.mu.RLock() + defer ecs.mu.RUnlock() + + _, exists := ecs.cache[apiId] + return exists +} + +// GetAPICache returns the embedding cache for a specific API ID +// Returns nil if the API ID doesn't exist in the cache +// Updates the API's last accessed timestamp +func (ecs *EmbeddingCacheStore) GetAPICache(apiId string) APIEmbeddingCache { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + if apiCache, exists := ecs.cache[apiId]; exists { + // Update API last accessed time + apiCache.LastAccessed = time.Now() + + // Create deep copy of the cache + copyCache := make(APIEmbeddingCache, len(apiCache.Tools)) + for k, v := range apiCache.Tools { + entryCopy := *v + entryCopy.Embedding = append([]float32(nil), v.Embedding...) 
+ copyCache[k] = &entryCopy + } + return copyCache + } + return nil +} + +// AddAPICache creates a new empty cache for the given API ID +// If a cache already exists for this API ID, it does nothing +// Evicts the LRU API if cache is at capacity +func (ecs *EmbeddingCacheStore) AddAPICache(apiId string) { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + if _, exists := ecs.cache[apiId]; !exists { + slog.Debug("Adding new API cache", "apiId", apiId, "currentCacheSize", len(ecs.cache), "maxAPIs", ecs.maxAPIs) + // Check if we need to evict an API before adding + ecs.evictLRUAPIIfNeeded() + + ecs.cache[apiId] = &APICache{ + Tools: make(map[string]*EmbeddingEntry), + LastAccessed: time.Now(), + } + slog.Debug("API cache added successfully", "apiId", apiId, "newCacheSize", len(ecs.cache)) + } else { + slog.Debug("API cache already exists", "apiId", apiId) + } +} + +// GetEntry retrieves an embedding entry for a specific API and hash key +// Returns nil if not found +// Updates both API and tool last accessed timestamps +func (ecs *EmbeddingCacheStore) GetEntry(apiId, hashKey string) *EmbeddingEntry { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + slog.Debug("GetEntry called", "apiId", apiId, "hashKey", safeHashPrefix(hashKey), "cachedAPIs", ecs.getCachedAPIIds()) + + if apiCache, exists := ecs.cache[apiId]; exists { + if entry, found := apiCache.Tools[hashKey]; found { + // Update timestamps on read + apiCache.LastAccessed = time.Now() + entry.LastAccessed = time.Now() + slog.Debug("GetEntry cache hit", "apiId", apiId, "toolName", entry.Name) + + // Return deep copy to prevent external mutations + entryCopy := *entry + entryCopy.Embedding = append([]float32(nil), entry.Embedding...) 
+ return &entryCopy + } + slog.Debug("GetEntry tool not found in API cache", "apiId", apiId) + } else { + slog.Debug("GetEntry API not found in cache", "apiId", apiId) + } + return nil +} + +// AddEntry adds or updates an embedding entry for a specific API +// If an entry with the same name exists in this API's cache, it removes the old one first +// The hashKey should be SHA-256 hash of the tool description +// Evicts LRU API if API cache is at capacity, and LRU tool if tool cache is at capacity +func (ecs *EmbeddingCacheStore) AddEntry(apiId, hashKey, name string, embedding []float32) { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + slog.Debug("AddEntry called", "apiId", apiId, "toolName", name, "cachedAPIs", ecs.getCachedAPIIds()) + + // Check if API cache exists, if not, check limits and possibly evict + if _, exists := ecs.cache[apiId]; !exists { + slog.Debug("AddEntry creating new API cache", "apiId", apiId, "currentSize", len(ecs.cache), "maxAPIs", ecs.maxAPIs) + ecs.evictLRUAPIIfNeeded() + ecs.cache[apiId] = &APICache{ + Tools: make(map[string]*EmbeddingEntry), + LastAccessed: time.Now(), + } + slog.Debug("AddEntry new API cache created", "apiId", apiId, "newSize", len(ecs.cache)) + } + + apiCache := ecs.cache[apiId] + // Update API last accessed time + apiCache.LastAccessed = time.Now() + + // Check if there's an existing entry with the same name and remove it + for key, entry := range apiCache.Tools { + if entry.Name == name { + delete(apiCache.Tools, key) + break + } + } + + // Check if we need to evict a tool before adding (only if this is a new entry) + if _, exists := apiCache.Tools[hashKey]; !exists { + ecs.evictLRUToolIfNeeded(apiCache) + } + + // Add new entry with current timestamp + apiCache.Tools[hashKey] = &EmbeddingEntry{ + Name: name, + Embedding: append([]float32(nil), embedding...), + LastAccessed: time.Now(), + } + slog.Debug("AddEntry tool added", "apiId", apiId, "toolName", name, "toolsInAPI", len(apiCache.Tools)) +} + +// ToolEntry 
represents a tool to be added to the cache +type ToolEntry struct { + HashKey string + Name string + Embedding []float32 +} + +// BulkAddResult contains the result of a bulk add operation +type BulkAddResult struct { + Added []string // Names of tools that were added to the cache + Skipped []string // Names of tools that were skipped due to cache limit + Cached []string // Names of tools that were already in cache (updated) +} + +// BulkAddTools adds multiple tools to the cache for a specific API in an optimized way. +// It first checks which tools are already cached, then only adds new tools up to the cache limit. +// This prevents wasteful evictions where a tool is evicted only for the next tool to also need eviction. +// +// Logic: +// 1. Separate tools into already-cached and new tools +// 2. Update timestamps for already-cached tools +// 3. Calculate available slots for new tools +// 4. Only add new tools that fit within the limit, skip the rest +// +// Returns BulkAddResult with lists of added, skipped, and already-cached tools +func (ecs *EmbeddingCacheStore) BulkAddTools(apiId string, tools []ToolEntry) BulkAddResult { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + result := BulkAddResult{ + Added: make([]string, 0), + Skipped: make([]string, 0), + Cached: make([]string, 0), + } + + if len(tools) == 0 { + return result + } + + slog.Debug("BulkAddTools called", "apiId", apiId, "toolCount", len(tools), "maxToolsPerAPI", ecs.maxToolsPerAPI) + + // Check if API cache exists, if not, create it + if _, exists := ecs.cache[apiId]; !exists { + slog.Debug("BulkAddTools creating new API cache", "apiId", apiId) + ecs.evictLRUAPIIfNeeded() + ecs.cache[apiId] = &APICache{ + Tools: make(map[string]*EmbeddingEntry), + LastAccessed: time.Now(), + } + } + + apiCache := ecs.cache[apiId] + apiCache.LastAccessed = time.Now() + + // Separate tools into already-cached and new tools + var newTools []ToolEntry + + for _, tool := range tools { + if entry, exists := 
apiCache.Tools[tool.HashKey]; exists { + // Tool already exists in cache - update timestamp and embedding + entry.LastAccessed = time.Now() + entry.Embedding = append([]float32(nil), tool.Embedding...) + result.Cached = append(result.Cached, tool.Name) + slog.Debug("BulkAddTools tool already cached", "toolName", tool.Name) + } else { + // Check if there's an existing entry with the same name (different hash) + for key, entry := range apiCache.Tools { + if entry.Name == tool.Name { + // Remove old entry with different hash, will be re-added with new hash + delete(apiCache.Tools, key) + break + } + } + newTools = append(newTools, tool) + } + } + + slog.Debug("BulkAddTools categorized tools", "cached", len(result.Cached), "new", len(newTools)) + + // Calculate available slots for new tools + availableSlots := ecs.maxToolsPerAPI - len(apiCache.Tools) + if availableSlots < 0 { + availableSlots = 0 + } + + slog.Debug("BulkAddTools available slots", "currentTools", len(apiCache.Tools), "maxTools", ecs.maxToolsPerAPI, "availableSlots", availableSlots) + + // Determine how many new tools we can add + toolsToAddCount := len(newTools) + if toolsToAddCount > availableSlots { + // Mark tools that won't fit as skipped + for _, tool := range newTools[availableSlots:] { + result.Skipped = append(result.Skipped, tool.Name) + slog.Debug("BulkAddTools skipping tool due to cache limit", "toolName", tool.Name) + } + toolsToAddCount = availableSlots + } + + // Add the new tools that fit + for i := 0; i < toolsToAddCount; i++ { + tool := newTools[i] + apiCache.Tools[tool.HashKey] = &EmbeddingEntry{ + Name: tool.Name, + Embedding: append([]float32(nil), tool.Embedding...), + LastAccessed: time.Now(), + } + result.Added = append(result.Added, tool.Name) + slog.Debug("BulkAddTools added new tool", "toolName", tool.Name) + } + + slog.Debug("BulkAddTools completed", "apiId", apiId, "added", len(result.Added), "skipped", len(result.Skipped), "cached", len(result.Cached), "totalToolsInCache", 
len(apiCache.Tools)) + + return result +} + +// RemoveAPI removes all cached embeddings for a specific API +func (ecs *EmbeddingCacheStore) RemoveAPI(apiId string) { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + delete(ecs.cache, apiId) +} + +// ClearAll removes all cached embeddings +func (ecs *EmbeddingCacheStore) ClearAll() { + ecs.mu.Lock() + defer ecs.mu.Unlock() + + ecs.cache = make(map[string]*APICache) +} + +// GetCacheStats returns statistics about the cache +func (ecs *EmbeddingCacheStore) GetCacheStats() (apiCount int, totalEntries int) { + ecs.mu.RLock() + defer ecs.mu.RUnlock() + + apiCount = len(ecs.cache) + for _, apiCache := range ecs.cache { + totalEntries += len(apiCache.Tools) + } + return +} + +// getCachedAPIIds returns a list of currently cached API IDs (must be called with lock held) +func (ecs *EmbeddingCacheStore) getCachedAPIIds() []string { + ids := make([]string, 0, len(ecs.cache)) + for id := range ecs.cache { + ids = append(ids, id) + } + return ids +} diff --git a/policies/semantic-tool-filtering/go.mod b/policies/semantic-tool-filtering/go.mod new file mode 100644 index 00000000..058cef8e --- /dev/null +++ b/policies/semantic-tool-filtering/go.mod @@ -0,0 +1,9 @@ +module github.com/wso2/gateway-controllers/policies/semantic-tool-filtering + +go 1.25.1 + +// require github.com/wso2/api-platform/sdk v0.3.7 + +require github.com/wso2/api-platform/sdk v0.4.5 + +require github.com/wso2/api-platform/sdk/core v0.1.0 diff --git a/policies/semantic-tool-filtering/go.sum b/policies/semantic-tool-filtering/go.sum new file mode 100644 index 00000000..2595cfd8 --- /dev/null +++ b/policies/semantic-tool-filtering/go.sum @@ -0,0 +1,4 @@ +github.com/wso2/api-platform/sdk v0.4.5 h1:1T5bY4W0No8SKZiydg8dt/x0iFKWiRTgquzoQd7EXU4= +github.com/wso2/api-platform/sdk v0.4.5/go.mod h1:Zoaj0dOuHmI96K9rY9w2qunjjT4Lcvlq9AYezL4a3/0= +github.com/wso2/api-platform/sdk/core v0.1.0 h1:yxxlgmrZLztCqHzAW7GdRnUNBFOA2l8AX6fphMZaly0= +github.com/wso2/api-platform/sdk/core 
v0.1.0/go.mod h1:R114dV8/EAl843+16zi97Cbd5+iPPYjHKdN7Nik3EUw= diff --git a/policies/semantic-tool-filtering/methodsforv1alpha.go b/policies/semantic-tool-filtering/methodsforv1alpha.go new file mode 100644 index 00000000..06e3222c --- /dev/null +++ b/policies/semantic-tool-filtering/methodsforv1alpha.go @@ -0,0 +1,577 @@ +/* + * Copyright (c) 2026, WSO2 LLC. (http://www.wso2.org) All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package semantictoolfiltering + +import ( + "encoding/json" + "fmt" + "log/slog" + "sort" + + policy "github.com/wso2/api-platform/sdk/gateway/policy/v1alpha" + utils "github.com/wso2/api-platform/sdk/utils" +) + +// OnRequest handles request body processing for semantic tool filtering (v1alpha) +func (p *SemanticToolFilteringPolicy) OnRequest(ctx *policy.RequestContext, params map[string]interface{}) policy.RequestAction { + var content []byte + if ctx.Body != nil { + content = ctx.Body.Content + } + + if len(content) == 0 { + slog.Debug("SemanticToolFiltering: Empty request body") + return policy.UpstreamRequestModifications{} + } + + // Handle based on format type (JSON or Text) + if p.userQueryIsJson && p.toolsIsJson { + // Pure JSON mode + return p.handleJSONRequest(ctx, content) + } else if !p.userQueryIsJson && !p.toolsIsJson { + // Pure Text mode + return p.handleTextRequest(ctx, content) + } else { + // Mixed mode + return p.handleMixedRequest(ctx, content) + } +} + +// handleJSONRequest 
handles requests where both user query and tools are in JSON format +func (p *SemanticToolFilteringPolicy) handleJSONRequest(ctx *policy.RequestContext, content []byte) policy.RequestAction { + // Parse request body as JSON + var requestBody map[string]interface{} + if err := json.Unmarshal(content, &requestBody); err != nil { + return p.buildErrorResponse("Invalid JSON in request body", err) + } + + // Extract user query using JSONPath + userQuery, err := utils.ExtractStringValueFromJsonpath(content, p.queryJSONPath) + if err != nil { + return p.buildErrorResponse("Error extracting user query from JSONPath", err) + } + + if userQuery == "" { + slog.Debug("SemanticToolFiltering: Empty user query") + return policy.UpstreamRequestModifications{} + } + + // Extract tools array using JSONPath + toolsJSON, err := utils.ExtractValueFromJsonpath(requestBody, p.toolsJSONPath) + if err != nil { + return p.buildErrorResponse("Error extracting tools from JSONPath", err) + } + + // Parse tools array + var tools []interface{} + var toolsBytes []byte + switch v := toolsJSON.(type) { + case []byte: + toolsBytes = v + case string: + toolsBytes = []byte(v) + default: + var err error + toolsBytes, err = json.Marshal(v) + if err != nil { + return p.buildErrorResponse("Invalid tools format in request", err) + } + } + if err := json.Unmarshal(toolsBytes, &tools); err != nil { + return p.buildErrorResponse("Invalid tools format in request", err) + } + + if len(tools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policy.UpstreamRequestModifications{} + } + + // Generate embedding for user query + queryEmbedding, err := p.embeddingProvider.GetEmbedding(userQuery) + if err != nil { + slog.Error("SemanticToolFiltering: Error generating query embedding", "error", err) + return p.buildErrorResponse("Error generating query embedding", err) + } + + // Get embedding cache instance + embeddingCache := GetEmbeddingCacheStoreInstance() + apiId := ctx.APIId + + 
embeddingCache.AddAPICache(apiId) + + // Prepare embedding requests for all valid tools + var embeddingRequests []toolEmbeddingRequest + toolDescMap := make(map[string]string) // hashKey -> toolDesc for similarity calculation + toolMapByHash := make(map[string]map[string]interface{}) // hashKey -> toolMap + + for _, toolRaw := range tools { + toolMap, ok := toolRaw.(map[string]interface{}) + if !ok { + slog.Warn("SemanticToolFiltering: Invalid tool format, skipping") + continue + } + + toolDesc := extractToolDescription(toolMap) + if toolDesc == "" { + slog.Warn("SemanticToolFiltering: No description found for tool, skipping", + "toolName", toolMap["name"]) + continue + } + + toolName, _ := toolMap["name"].(string) + descHash := p.getCacheKey(toolDesc) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: toolName, + Description: toolDesc, + HashKey: descHash, + }) + toolDescMap[descHash] = toolDesc + toolMapByHash[descHash] = toolMap + } + + // Process embeddings with proper cache management (avoids cascade evictions) + embeddingResults := p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests) + + // Calculate similarity scores for tools that have embeddings + toolsWithScores := make([]ToolWithScore, 0, len(embeddingResults)) + for hashKey, result := range embeddingResults { + toolMap := toolMapByHash[hashKey] + if toolMap == nil { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", + "error", err, "toolName", result.Name) + continue + } + + toolsWithScores = append(toolsWithScores, ToolWithScore{ + Tool: toolMap, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policy.UpstreamRequestModifications{} + } + + // Filter tools based on selection mode + filteredTools := 
p.filterTools(toolsWithScores) + + slog.Debug("SemanticToolFiltering: Filtered tools", + "originalCount", len(tools), + "filteredCount", len(filteredTools), + "selectionMode", p.selectionMode) + + // Update request body with filtered tools + if err := updateToolsInRequestBody(&requestBody, p.toolsJSONPath, filteredTools); err != nil { + return p.buildErrorResponse("Error updating request body with filtered tools", err) + } + + // Marshal modified request body + modifiedBody, err := json.Marshal(requestBody) + if err != nil { + return p.buildErrorResponse("Error marshaling modified request body", err) + } + + return policy.UpstreamRequestModifications{ + Body: modifiedBody, + } +} + +// handleTextRequest handles requests where both user query and tools are in text format with tags +func (p *SemanticToolFilteringPolicy) handleTextRequest(ctx *policy.RequestContext, content []byte) policy.RequestAction { + contentStr := string(content) + + // Extract user query from tags + userQuery, err := extractUserQueryFromText(contentStr) + if err != nil { + return p.buildErrorResponse("Error extracting user query from text", err) + } + + if userQuery == "" { + slog.Debug("SemanticToolFiltering: Empty user query") + return policy.UpstreamRequestModifications{} + } + + // Extract tools from and tags + textTools, err := extractToolsFromText(contentStr) + if err != nil { + return p.buildErrorResponse("Error extracting tools from text", err) + } + + if len(textTools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policy.UpstreamRequestModifications{} + } + + // Generate embedding for user query + queryEmbedding, err := p.embeddingProvider.GetEmbedding(userQuery) + + if err != nil { + slog.Error("SemanticToolFiltering: Error generating query embedding", "error", err) + return p.buildErrorResponse("Error generating query embedding", err) + } + + // Get embedding cache instance + embeddingCache := GetEmbeddingCacheStoreInstance() + apiId := ctx.APIId + + 
embeddingCache.AddAPICache(apiId) + + // Prepare embedding requests for all text tools + var embeddingRequests []toolEmbeddingRequest + textToolByHash := make(map[string]TextTool) // hashKey -> TextTool + + for _, tool := range textTools { + toolText := fmt.Sprintf("%s: %s", tool.Name, tool.Description) + textHash := p.getCacheKey(toolText) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: tool.Name, + Description: toolText, + HashKey: textHash, + }) + textToolByHash[textHash] = tool + } + + // Process embeddings with proper cache management (avoids cascade evictions) + embeddingResults := p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests) + + // Calculate similarity scores for tools that have embeddings + type TextToolWithScore struct { + Tool TextTool + Score float64 + } + toolsWithScores := make([]TextToolWithScore, 0, len(embeddingResults)) + + for hashKey, result := range embeddingResults { + textTool, ok := textToolByHash[hashKey] + if !ok { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", + "error", err, "toolName", result.Name) + continue + } + + toolsWithScores = append(toolsWithScores, TextToolWithScore{ + Tool: textTool, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policy.UpstreamRequestModifications{} + } + + // Sort by score in descending order + sort.Slice(toolsWithScores, func(i, j int) bool { + return toolsWithScores[i].Score > toolsWithScores[j].Score + }) + + // Filter based on selection mode + filteredToolNames := make(map[string]bool) + switch p.selectionMode { + case SelectionModeTopK: + limit := p.topK + if limit > len(toolsWithScores) { + limit = len(toolsWithScores) + } + for i := 0; i < limit; i++ { + 
filteredToolNames[toolsWithScores[i].Tool.Name] = true + } + + case SelectionModeThreshold: + for _, item := range toolsWithScores { + if item.Score >= p.threshold { + filteredToolNames[item.Tool.Name] = true + } + } + } + + // Rebuild text content with only filtered tools and strip all tags + modifiedContent := rebuildTextWithFilteredTools(contentStr, textTools, filteredToolNames) + modifiedContent = stripAllTags(modifiedContent) + + slog.Debug("SemanticToolFiltering: Filtered text tools", + "originalCount", len(textTools), + "filteredCount", len(filteredToolNames), + "selectionMode", p.selectionMode) + + return policy.UpstreamRequestModifications{ + Body: []byte(modifiedContent), + } +} + +// handleMixedRequest handles requests where user query and tools have different formats +func (p *SemanticToolFilteringPolicy) handleMixedRequest(ctx *policy.RequestContext, content []byte) policy.RequestAction { + // For mixed mode, parse based on each component's format + contentStr := string(content) + var userQuery string + var err error + + // Extract user query based on format + if p.userQueryIsJson { + var requestBody map[string]interface{} + if err := json.Unmarshal(content, &requestBody); err != nil { + return p.buildErrorResponse("Invalid JSON in request body", err) + } + userQuery, err = utils.ExtractStringValueFromJsonpath(content, p.queryJSONPath) + if err != nil { + return p.buildErrorResponse("Error extracting user query from JSONPath", err) + } + } else { + userQuery, err = extractUserQueryFromText(contentStr) + if err != nil { + return p.buildErrorResponse("Error extracting user query from text", err) + } + } + + if userQuery == "" { + slog.Debug("SemanticToolFiltering: Empty user query") + return policy.UpstreamRequestModifications{} + } + + // Generate embedding for user query + queryEmbedding, err := p.embeddingProvider.GetEmbedding(userQuery) + if err != nil { + slog.Error("SemanticToolFiltering: Error generating query embedding", "error", err) + return 
p.buildErrorResponse("Error generating query embedding", err) + } + + // Get embedding cache instance + embeddingCache := GetEmbeddingCacheStoreInstance() + apiId := ctx.APIId + + embeddingCache.AddAPICache(apiId) + + // Handle tools based on format + if p.toolsIsJson { + // Tools are in JSON format + var requestBody map[string]interface{} + if err := json.Unmarshal(content, &requestBody); err != nil { + return p.buildErrorResponse("Invalid JSON in request body", err) + } + + toolsJSON, err := utils.ExtractValueFromJsonpath(requestBody, p.toolsJSONPath) + if err != nil { + return p.buildErrorResponse("Error extracting tools from JSONPath", err) + } + + var tools []interface{} + var toolsBytes []byte + switch v := toolsJSON.(type) { + case []byte: + toolsBytes = v + case string: + toolsBytes = []byte(v) + default: + var err error + toolsBytes, err = json.Marshal(v) + if err != nil { + return p.buildErrorResponse("Invalid tools format in request", err) + } + } + if err := json.Unmarshal(toolsBytes, &tools); err != nil { + return p.buildErrorResponse("Invalid tools format in request", err) + } + + if len(tools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policy.UpstreamRequestModifications{} + } + + var embeddingRequests []toolEmbeddingRequest + toolMapByHash := make(map[string]map[string]interface{}) + + for _, toolRaw := range tools { + toolMap, ok := toolRaw.(map[string]interface{}) + if !ok { + slog.Warn("SemanticToolFiltering: Invalid tool format, skipping") + continue + } + + toolDesc := extractToolDescription(toolMap) + if toolDesc == "" { + slog.Warn("SemanticToolFiltering: No description found for tool, skipping") + continue + } + + toolName, _ := toolMap["name"].(string) + descHash := p.getCacheKey(toolDesc) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: toolName, + Description: toolDesc, + HashKey: descHash, + }) + toolMapByHash[descHash] = toolMap + } + + embeddingResults := 
p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests) + + toolsWithScores := make([]ToolWithScore, 0, len(embeddingResults)) + for hashKey, result := range embeddingResults { + toolMap := toolMapByHash[hashKey] + if toolMap == nil { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", "error", err) + continue + } + + toolsWithScores = append(toolsWithScores, ToolWithScore{ + Tool: toolMap, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policy.UpstreamRequestModifications{} + } + + filteredTools := p.filterTools(toolsWithScores) + + if err := updateToolsInRequestBody(&requestBody, p.toolsJSONPath, filteredTools); err != nil { + return p.buildErrorResponse("Error updating request body with filtered tools", err) + } + + modifiedBody, err := json.Marshal(requestBody) + if err != nil { + return p.buildErrorResponse("Error marshaling modified request body", err) + } + + return policy.UpstreamRequestModifications{ + Body: modifiedBody, + } + } else { + // Tools are in text format + textTools, err := extractToolsFromText(contentStr) + if err != nil { + return p.buildErrorResponse("Error extracting tools from text", err) + } + + if len(textTools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policy.UpstreamRequestModifications{} + } + + var embeddingRequests []toolEmbeddingRequest + textToolByHash := make(map[string]TextTool) + + for _, tool := range textTools { + toolText := fmt.Sprintf("%s: %s", tool.Name, tool.Description) + textHash := p.getCacheKey(toolText) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: tool.Name, + Description: toolText, + HashKey: textHash, + }) + textToolByHash[textHash] = tool + } + + embeddingResults := 
p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests) + + type TextToolWithScore struct { + Tool TextTool + Score float64 + } + toolsWithScores := make([]TextToolWithScore, 0, len(embeddingResults)) + + for hashKey, result := range embeddingResults { + textTool, ok := textToolByHash[hashKey] + if !ok { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", "error", err) + continue + } + + toolsWithScores = append(toolsWithScores, TextToolWithScore{ + Tool: textTool, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policy.UpstreamRequestModifications{} + } + + sort.Slice(toolsWithScores, func(i, j int) bool { + return toolsWithScores[i].Score > toolsWithScores[j].Score + }) + + filteredToolNames := make(map[string]bool) + switch p.selectionMode { + case SelectionModeTopK: + limit := p.topK + if limit > len(toolsWithScores) { + limit = len(toolsWithScores) + } + for i := 0; i < limit; i++ { + filteredToolNames[toolsWithScores[i].Tool.Name] = true + } + + case SelectionModeThreshold: + for _, item := range toolsWithScores { + if item.Score >= p.threshold { + filteredToolNames[item.Tool.Name] = true + } + } + } + + modifiedContent := rebuildTextWithFilteredTools(contentStr, textTools, filteredToolNames) + modifiedContent = stripAllTags(modifiedContent) + + return policy.UpstreamRequestModifications{ + Body: []byte(modifiedContent), + } + } +} + +// buildErrorResponse builds an error response +func (p *SemanticToolFilteringPolicy) buildErrorResponse(message string, err error) policy.RequestAction { + // Log a warning with error details for diagnostics, but do not expose + // internal error details to clients. Continue the request unmodified. 
+ if err != nil { + slog.Warn("SemanticToolFiltering: "+message, "error", err) + } else { + slog.Warn("SemanticToolFiltering: " + message) + } + + // Return a pass-through action so the original request proceeds unchanged. + return policy.UpstreamRequestModifications{} +} diff --git a/policies/semantic-tool-filtering/methodsforv1alpha2.go b/policies/semantic-tool-filtering/methodsforv1alpha2.go new file mode 100644 index 00000000..89ea87f8 --- /dev/null +++ b/policies/semantic-tool-filtering/methodsforv1alpha2.go @@ -0,0 +1,581 @@ +/* + * Copyright (c) 2026, WSO2 LLC. (http://www.wso2.org) All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package semantictoolfiltering + +import ( + "encoding/json" + "fmt" + "log/slog" + "sort" + + policyv1alpha2 "github.com/wso2/api-platform/sdk/core/policy/v1alpha2" + utils "github.com/wso2/api-platform/sdk/utils" +) + +// OnRequestBody is the v1alpha2 factory entry point (loaded by v1alpha2 kernels). 
+func (p *SemanticToolFilteringPolicy) OnRequestBody(ctx *policyv1alpha2.RequestContext, _ map[string]interface{}) policyv1alpha2.RequestAction { + return p.processRequestBody(ctx) +} + +func (p *SemanticToolFilteringPolicy) processRequestBody(ctx *policyv1alpha2.RequestContext) policyv1alpha2.RequestAction { + var content []byte + if ctx.Body != nil { + content = ctx.Body.Content + } + + if len(content) == 0 { + slog.Debug("SemanticToolFiltering: Empty request body") + return policyv1alpha2.UpstreamRequestModifications{} + } + + // Handle based on format type (JSON or Text) + if p.userQueryIsJson && p.toolsIsJson { + // Pure JSON mode + return p.handleJSONRequestV2(ctx, content) + } else if !p.userQueryIsJson && !p.toolsIsJson { + // Pure Text mode + return p.handleTextRequestV2(ctx, content) + } else { + // Mixed mode + return p.handleMixedRequestV2(ctx, content) + } +} + +// handleJSONRequestV2 handles requests where both user query and tools are in JSON format (v1alpha2) +func (p *SemanticToolFilteringPolicy) handleJSONRequestV2(ctx *policyv1alpha2.RequestContext, content []byte) policyv1alpha2.RequestAction { + // Parse request body as JSON + var requestBody map[string]interface{} + if err := json.Unmarshal(content, &requestBody); err != nil { + return p.buildErrorResponseV2("Invalid JSON in request body", err) + } + + // Extract user query using JSONPath + userQuery, err := utils.ExtractStringValueFromJsonpath(content, p.queryJSONPath) + if err != nil { + return p.buildErrorResponseV2("Error extracting user query from JSONPath", err) + } + + if userQuery == "" { + slog.Debug("SemanticToolFiltering: Empty user query") + return policyv1alpha2.UpstreamRequestModifications{} + } + + // Extract tools array using JSONPath + toolsJSON, err := utils.ExtractValueFromJsonpath(requestBody, p.toolsJSONPath) + if err != nil { + return p.buildErrorResponseV2("Error extracting tools from JSONPath", err) + } + + // Parse tools array + var tools []interface{} + var toolsBytes 
[]byte + switch v := toolsJSON.(type) { + case []byte: + toolsBytes = v + case string: + toolsBytes = []byte(v) + default: + var err error + toolsBytes, err = json.Marshal(v) + if err != nil { + return p.buildErrorResponseV2("Invalid tools format in request", err) + } + } + if err := json.Unmarshal(toolsBytes, &tools); err != nil { + return p.buildErrorResponseV2("Invalid tools format in request", err) + } + + if len(tools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policyv1alpha2.UpstreamRequestModifications{} + } + + // Generate embedding for user query + queryEmbedding, err := p.embeddingProvider.GetEmbedding(userQuery) + if err != nil { + slog.Error("SemanticToolFiltering: Error generating query embedding", "error", err) + return p.buildErrorResponseV2("Error generating query embedding", err) + } + + // Get embedding cache instance + embeddingCache := GetEmbeddingCacheStoreInstance() + apiId := ctx.APIId + + embeddingCache.AddAPICache(apiId) + + // Prepare embedding requests for all valid tools + var embeddingRequests []toolEmbeddingRequest + toolDescMap := make(map[string]string) // hashKey -> toolDesc for similarity calculation + toolMapByHash := make(map[string]map[string]interface{}) // hashKey -> toolMap + + for _, toolRaw := range tools { + toolMap, ok := toolRaw.(map[string]interface{}) + if !ok { + slog.Warn("SemanticToolFiltering: Invalid tool format, skipping") + continue + } + + toolDesc := extractToolDescription(toolMap) + if toolDesc == "" { + slog.Warn("SemanticToolFiltering: No description found for tool, skipping", + "toolName", toolMap["name"]) + continue + } + + toolName, _ := toolMap["name"].(string) + descHash := p.getCacheKey(toolDesc) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: toolName, + Description: toolDesc, + HashKey: descHash, + }) + toolDescMap[descHash] = toolDesc + toolMapByHash[descHash] = toolMap + } + + // Process embeddings with proper cache management (avoids 
cascade evictions)
+ embeddingResults := p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests)
+
+ // Calculate similarity scores for tools that have embeddings
+ toolsWithScores := make([]ToolWithScore, 0, len(embeddingResults))
+ for hashKey, result := range embeddingResults {
+ toolMap := toolMapByHash[hashKey]
+ if toolMap == nil {
+ continue
+ }
+
+ similarity, err := cosineSimilarity(queryEmbedding, result.Embedding)
+ if err != nil {
+ slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping",
+ "error", err, "toolName", result.Name)
+ continue
+ }
+
+ toolsWithScores = append(toolsWithScores, ToolWithScore{
+ Tool: toolMap,
+ Score: similarity,
+ })
+ }
+
+ if len(toolsWithScores) == 0 {
+ slog.Debug("SemanticToolFiltering: No valid tools after embedding generation")
+ return policyv1alpha2.UpstreamRequestModifications{}
+ }
+
+ // Filter tools based on selection mode
+ filteredTools := p.filterTools(toolsWithScores)
+
+ slog.Debug("SemanticToolFiltering: Filtered tools",
+ "originalCount", len(tools),
+ "filteredCount", len(filteredTools),
+ "selectionMode", p.selectionMode)
+
+ // Update request body with filtered tools
+ if err := updateToolsInRequestBody(&requestBody, p.toolsJSONPath, filteredTools); err != nil {
+ return p.buildErrorResponseV2("Error updating request body with filtered tools", err)
+ }
+
+ // Marshal modified request body
+ modifiedBody, err := json.Marshal(requestBody)
+ if err != nil {
+ return p.buildErrorResponseV2("Error marshaling modified request body", err)
+ }
+
+ return policyv1alpha2.UpstreamRequestModifications{
+ Body: modifiedBody,
+ }
+}
+
+// handleTextRequestV2 handles requests where both the user query and the tools are embedded in the text body as tagged sections (v1alpha2)
+func (p *SemanticToolFilteringPolicy) handleTextRequestV2(ctx *policyv1alpha2.RequestContext, content []byte) policyv1alpha2.RequestAction {
+ contentStr := string(content)
+
+ // Extract the user query from its tagged section in the text
+ userQuery, err :=
extractUserQueryFromText(contentStr)
+ if err != nil {
+ return p.buildErrorResponseV2("Error extracting user query from text", err)
+ }
+
+ if userQuery == "" {
+ slog.Debug("SemanticToolFiltering: Empty user query")
+ return policyv1alpha2.UpstreamRequestModifications{}
+ }
+
+ // Extract tool definitions from their tagged sections in the text
+ textTools, err := extractToolsFromText(contentStr)
+ if err != nil {
+ return p.buildErrorResponseV2("Error extracting tools from text", err)
+ }
+
+ if len(textTools) == 0 {
+ slog.Debug("SemanticToolFiltering: No tools to filter")
+ return policyv1alpha2.UpstreamRequestModifications{}
+ }
+
+ // Generate embedding for user query
+ queryEmbedding, err := p.embeddingProvider.GetEmbedding(userQuery)
+
+ if err != nil {
+ slog.Error("SemanticToolFiltering: Error generating query embedding", "error", err)
+ return p.buildErrorResponseV2("Error generating query embedding", err)
+ }
+
+ // Get embedding cache instance
+ embeddingCache := GetEmbeddingCacheStoreInstance()
+ apiId := ctx.APIId
+
+ embeddingCache.AddAPICache(apiId)
+
+ // Prepare embedding requests for all text tools
+ var embeddingRequests []toolEmbeddingRequest
+ textToolByHash := make(map[string]TextTool) // hashKey -> TextTool
+
+ for _, tool := range textTools {
+ toolText := fmt.Sprintf("%s: %s", tool.Name, tool.Description)
+ textHash := p.getCacheKey(toolText)
+
+ embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{
+ Name: tool.Name,
+ Description: toolText,
+ HashKey: textHash,
+ })
+ textToolByHash[textHash] = tool
+ }
+
+ // Process embeddings with proper cache management (avoids cascade evictions)
+ embeddingResults := p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests)
+
+ // Calculate similarity scores for tools that have embeddings
+ type TextToolWithScore struct {
+ Tool TextTool
+ Score float64
+ }
+ toolsWithScores := make([]TextToolWithScore, 0, len(embeddingResults))
+
+ for hashKey, result := range embeddingResults {
+ textTool, ok :=
textToolByHash[hashKey] + if !ok { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", + "error", err, "toolName", result.Name) + continue + } + + toolsWithScores = append(toolsWithScores, TextToolWithScore{ + Tool: textTool, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policyv1alpha2.UpstreamRequestModifications{} + } + + // Sort by score in descending order + sort.Slice(toolsWithScores, func(i, j int) bool { + return toolsWithScores[i].Score > toolsWithScores[j].Score + }) + + // Filter based on selection mode + filteredToolNames := make(map[string]bool) + switch p.selectionMode { + case SelectionModeTopK: + limit := p.topK + if limit > len(toolsWithScores) { + limit = len(toolsWithScores) + } + for i := 0; i < limit; i++ { + filteredToolNames[toolsWithScores[i].Tool.Name] = true + } + + case SelectionModeThreshold: + for _, item := range toolsWithScores { + if item.Score >= p.threshold { + filteredToolNames[item.Tool.Name] = true + } + } + } + + // Rebuild text content with only filtered tools and strip all tags + modifiedContent := rebuildTextWithFilteredTools(contentStr, textTools, filteredToolNames) + modifiedContent = stripAllTags(modifiedContent) + + slog.Debug("SemanticToolFiltering: Filtered text tools", + "originalCount", len(textTools), + "filteredCount", len(filteredToolNames), + "selectionMode", p.selectionMode) + + return policyv1alpha2.UpstreamRequestModifications{ + Body: []byte(modifiedContent), + } +} + +// handleMixedRequestV2 handles requests where user query and tools have different formats (v1alpha2) +func (p *SemanticToolFilteringPolicy) handleMixedRequestV2(ctx *policyv1alpha2.RequestContext, content []byte) policyv1alpha2.RequestAction { + // For mixed mode, parse based on each component's format + 
contentStr := string(content) + var userQuery string + var err error + + // Extract user query based on format + if p.userQueryIsJson { + var requestBody map[string]interface{} + if err := json.Unmarshal(content, &requestBody); err != nil { + return p.buildErrorResponseV2("Invalid JSON in request body", err) + } + userQuery, err = utils.ExtractStringValueFromJsonpath(content, p.queryJSONPath) + if err != nil { + return p.buildErrorResponseV2("Error extracting user query from JSONPath", err) + } + } else { + userQuery, err = extractUserQueryFromText(contentStr) + if err != nil { + return p.buildErrorResponseV2("Error extracting user query from text", err) + } + } + + if userQuery == "" { + slog.Debug("SemanticToolFiltering: Empty user query") + return policyv1alpha2.UpstreamRequestModifications{} + } + + // Generate embedding for user query + queryEmbedding, err := p.embeddingProvider.GetEmbedding(userQuery) + if err != nil { + slog.Error("SemanticToolFiltering: Error generating query embedding", "error", err) + return p.buildErrorResponseV2("Error generating query embedding", err) + } + + // Get embedding cache instance + embeddingCache := GetEmbeddingCacheStoreInstance() + apiId := ctx.APIId + + embeddingCache.AddAPICache(apiId) + + // Handle tools based on format + if p.toolsIsJson { + // Tools are in JSON format + var requestBody map[string]interface{} + if err := json.Unmarshal(content, &requestBody); err != nil { + return p.buildErrorResponseV2("Invalid JSON in request body", err) + } + + toolsJSON, err := utils.ExtractValueFromJsonpath(requestBody, p.toolsJSONPath) + if err != nil { + return p.buildErrorResponseV2("Error extracting tools from JSONPath", err) + } + + var tools []interface{} + var toolsBytes []byte + switch v := toolsJSON.(type) { + case []byte: + toolsBytes = v + case string: + toolsBytes = []byte(v) + default: + var err error + toolsBytes, err = json.Marshal(v) + if err != nil { + return p.buildErrorResponseV2("Invalid tools format in 
request", err) + } + } + if err := json.Unmarshal(toolsBytes, &tools); err != nil { + return p.buildErrorResponseV2("Invalid tools format in request", err) + } + + if len(tools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policyv1alpha2.UpstreamRequestModifications{} + } + + var embeddingRequests []toolEmbeddingRequest + toolMapByHash := make(map[string]map[string]interface{}) + + for _, toolRaw := range tools { + toolMap, ok := toolRaw.(map[string]interface{}) + if !ok { + slog.Warn("SemanticToolFiltering: Invalid tool format, skipping") + continue + } + + toolDesc := extractToolDescription(toolMap) + if toolDesc == "" { + slog.Warn("SemanticToolFiltering: No description found for tool, skipping") + continue + } + + toolName, _ := toolMap["name"].(string) + descHash := p.getCacheKey(toolDesc) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: toolName, + Description: toolDesc, + HashKey: descHash, + }) + toolMapByHash[descHash] = toolMap + } + + embeddingResults := p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests) + + toolsWithScores := make([]ToolWithScore, 0, len(embeddingResults)) + for hashKey, result := range embeddingResults { + toolMap := toolMapByHash[hashKey] + if toolMap == nil { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", "error", err) + continue + } + + toolsWithScores = append(toolsWithScores, ToolWithScore{ + Tool: toolMap, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policyv1alpha2.UpstreamRequestModifications{} + } + + filteredTools := p.filterTools(toolsWithScores) + + if err := updateToolsInRequestBody(&requestBody, p.toolsJSONPath, filteredTools); err != nil { + return p.buildErrorResponseV2("Error updating request 
body with filtered tools", err) + } + + modifiedBody, err := json.Marshal(requestBody) + if err != nil { + return p.buildErrorResponseV2("Error marshaling modified request body", err) + } + + return policyv1alpha2.UpstreamRequestModifications{ + Body: modifiedBody, + } + } else { + // Tools are in text format + textTools, err := extractToolsFromText(contentStr) + if err != nil { + return p.buildErrorResponseV2("Error extracting tools from text", err) + } + + if len(textTools) == 0 { + slog.Debug("SemanticToolFiltering: No tools to filter") + return policyv1alpha2.UpstreamRequestModifications{} + } + + var embeddingRequests []toolEmbeddingRequest + textToolByHash := make(map[string]TextTool) + + for _, tool := range textTools { + toolText := fmt.Sprintf("%s: %s", tool.Name, tool.Description) + textHash := p.getCacheKey(toolText) + + embeddingRequests = append(embeddingRequests, toolEmbeddingRequest{ + Name: tool.Name, + Description: toolText, + HashKey: textHash, + }) + textToolByHash[textHash] = tool + } + + embeddingResults := p.processToolEmbeddingsWithCache(embeddingCache, apiId, embeddingRequests) + + type TextToolWithScore struct { + Tool TextTool + Score float64 + } + toolsWithScores := make([]TextToolWithScore, 0, len(embeddingResults)) + + for hashKey, result := range embeddingResults { + textTool, ok := textToolByHash[hashKey] + if !ok { + continue + } + + similarity, err := cosineSimilarity(queryEmbedding, result.Embedding) + if err != nil { + slog.Warn("SemanticToolFiltering: Error calculating similarity, skipping", "error", err) + continue + } + + toolsWithScores = append(toolsWithScores, TextToolWithScore{ + Tool: textTool, + Score: similarity, + }) + } + + if len(toolsWithScores) == 0 { + slog.Debug("SemanticToolFiltering: No valid tools after embedding generation") + return policyv1alpha2.UpstreamRequestModifications{} + } + + sort.Slice(toolsWithScores, func(i, j int) bool { + return toolsWithScores[i].Score > toolsWithScores[j].Score + }) + + 
filteredToolNames := make(map[string]bool) + switch p.selectionMode { + case SelectionModeTopK: + limit := p.topK + if limit > len(toolsWithScores) { + limit = len(toolsWithScores) + } + for i := 0; i < limit; i++ { + filteredToolNames[toolsWithScores[i].Tool.Name] = true + } + + case SelectionModeThreshold: + for _, item := range toolsWithScores { + if item.Score >= p.threshold { + filteredToolNames[item.Tool.Name] = true + } + } + } + + modifiedContent := rebuildTextWithFilteredTools(contentStr, textTools, filteredToolNames) + modifiedContent = stripAllTags(modifiedContent) + + return policyv1alpha2.UpstreamRequestModifications{ + Body: []byte(modifiedContent), + } + } +} + +// buildErrorResponseV2 builds an error response for v1alpha2 +func (p *SemanticToolFilteringPolicy) buildErrorResponseV2(message string, err error) policyv1alpha2.RequestAction { + // Log a warning with error details for diagnostics, but do not expose + // internal error details to clients. Continue the request unmodified. + if err != nil { + slog.Warn("SemanticToolFiltering: "+message, "error", err) + } else { + slog.Warn("SemanticToolFiltering: " + message) + } + + // Return a pass-through action so the original request proceeds unchanged. + return policyv1alpha2.UpstreamRequestModifications{} +} diff --git a/policies/semantic-tool-filtering/policy-definition.yaml b/policies/semantic-tool-filtering/policy-definition.yaml new file mode 100644 index 00000000..89f629b5 --- /dev/null +++ b/policies/semantic-tool-filtering/policy-definition.yaml @@ -0,0 +1,110 @@ +name: semantic-tool-filtering +version: v0.1.0 +description: | + Dynamically filters the tools provided within an API request based on their semantic relevance to the + user query. This policy extracts both the query and the tool definitions from the incoming payload, + generates an embedding for the query, and performs a similarity search against the provided tools. 
+  It then replaces the original 'tools' array with a filtered subset, optimizing the request before it reaches the LLM.
+
+  Note: This policy requires an embedding provider (OpenAI, Mistral, or Azure OpenAI) to be configured.
+
+parameters:
+  type: object
+  properties:
+    selectionMode:
+      type: string
+      x-wso2-policy-advanced-param: true
+      description: |
+        Method used to filter tools:
+        - "By Rank": Selects a fixed number of the most relevant tools.
+        - "By Threshold": Selects all tools exceeding a specific similarity score.
+      enum:
+        - By Rank
+        - By Threshold
+      default: "By Rank"
+    limit:
+      type: integer
+      x-wso2-policy-advanced-param: true
+      description: "The number of most relevant tools to include (used if selectionMode is By Rank)."
+      minimum: 0
+      maximum: 20
+      default: 5
+    threshold:
+      type: number
+      x-wso2-policy-advanced-param: true
+      description: |
+        Similarity threshold for filtering (0.0 to 1.0).
+        Only tools with a score above this value are included (used if selectionMode is By Threshold).
+      minimum: 0.0
+      maximum: 1.0
+      default: 0.7
+    queryJSONPath:
+      type: string
+      x-wso2-policy-advanced-param: true
+      description: |
+        JSONPath expression to extract the user's query from the request body.
+        Example: "$.messages[-1].content"
+      default: "$.messages[-1].content"
+    toolsJSONPath:
+      type: string
+      x-wso2-policy-advanced-param: true
+      description: |
+        JSONPath expression to extract the tools array from the request body (used when toolsIsJson is true).
+        Example: "$.tools"
+      default: "$.tools"
+      pattern: '^(?:\$\.)?[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*$'
+    userQueryIsJson:
+      type: boolean
+      x-wso2-policy-advanced-param: true
+      description: |
+        Specifies the format of user query extraction:
+        - true: Extract the user query using the JSONPath expression from the queryJSONPath parameter
+        - false: Extract the user query from its tagged section in the text
+      default: true
+    toolsIsJson:
+      type: boolean
+      x-wso2-policy-advanced-param: true
+      description: |
+        Specifies the format of the tools definition:
+        - true: Extract tools using the JSONPath expression from the toolsJSONPath parameter (JSON format)
+        - false: Extract tool definitions from their tagged sections in the text
+
+        Text format example (each tool name followed by its description):
+        get_weather
+        Get current weather for a location
+        search_web
+        Search the web for information
+      default: true
+  required:
+    - selectionMode
+
+systemParameters:
+  type: object
+  x-wso2-policy-advanced-param: true
+  properties:
+    embeddingProvider:
+      type: string
+      description: "Embedding provider: OPENAI, MISTRAL, or AZURE_OPENAI"
+      enum:
+        - OPENAI
+        - MISTRAL
+        - AZURE_OPENAI
+      "wso2/defaultValue": "${config.embedding_provider}"
+    embeddingEndpoint:
+      type: string
+      description: "Endpoint URL for the embedding service"
+      "wso2/defaultValue": "${config.embedding_provider_endpoint}"
+      minLength: 1
+    embeddingModel:
+      type: string
+      description: "Model name (e.g., text-embedding-3-small or mistral-embed)"
+      "wso2/defaultValue": "${config.embedding_provider_model}"
+    apiKey:
+      type: string
+      description: "API key for the embedding service"
+      "wso2/defaultValue": "${config.embedding_provider_api_key}"
+      minLength: 1
+  required:
+    - embeddingProvider
+    - embeddingEndpoint
+    - apiKey
diff --git a/policies/semantic-tool-filtering/semantictoolfiltering.go b/policies/semantic-tool-filtering/semantictoolfiltering.go
new file mode 100644
index 00000000..caec6a2a
--- /dev/null
+++ b/policies/semantic-tool-filtering/semantictoolfiltering.go
@@ -0,0 +1,846 @@
+/*
+ * Copyright (c) 2026, WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package semantictoolfiltering
+
+import (
+ "fmt"
+ "log/slog"
+ "math"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+
+ policyv1alpha2 "github.com/wso2/api-platform/sdk/core/policy/v1alpha2"
+ policy "github.com/wso2/api-platform/sdk/gateway/policy/v1alpha"
+ embeddingproviders "github.com/wso2/api-platform/sdk/utils/embeddingproviders"
+)
+
+const (
+ // Selection modes
+ SelectionModeTopK = "By Rank"
+ SelectionModeThreshold = "By Threshold"
+
+ // Internal timeout for embedding provider (not exposed in policy definition)
+ DefaultTimeoutMs = 5000
+)
+
+// ToolWithScore represents a tool with its similarity score
+type ToolWithScore struct {
+ Tool map[string]interface{}
+ Score float64
+}
+
+// TextTool represents a tool parsed from text format
+type TextTool struct {
+ Name string
+ Description string
+ StartPos int // Start position in original text
+ EndPos int // End position in original text (immediately after the closing tag)
+}
+
+// SemanticToolFilteringPolicy implements semantic filtering for tool selection
+type SemanticToolFilteringPolicy struct {
+ embeddingConfig embeddingproviders.EmbeddingProviderConfig
+ embeddingProvider embeddingproviders.EmbeddingProvider
+ selectionMode string
+ topK int
+ threshold float64
+ queryJSONPath string
+ toolsJSONPath string
+ userQueryIsJson bool
+ toolsIsJson bool
+}
+
+// getCacheKey generates a cache key that includes the embedding provider and model
+// to avoid returning stale/incompatible embeddings if the provider or model changes.
+// The key format is: hash(provider:model:description) +func (p *SemanticToolFilteringPolicy) getCacheKey(description string) string { + // Combine provider, model, and description to create a unique cache key + providerModel := fmt.Sprintf("%s:%s", p.embeddingConfig.EmbeddingProvider, p.embeddingConfig.EmbeddingModel) + combinedKey := fmt.Sprintf("%s:%s", providerModel, description) + return HashDescription(combinedKey) +} + +// toolEmbeddingRequest represents a tool that needs embedding processing +type toolEmbeddingRequest struct { + Name string + Description string // The text to generate embedding for + HashKey string // Pre-computed cache key +} + +// toolEmbeddingResult represents a tool with its embedding +type toolEmbeddingResult struct { + Name string + HashKey string + Embedding []float32 + FromCache bool +} + +// processToolEmbeddingsWithCache processes tool embeddings with proper cache management. +// It first checks which tools are already cached, then generates embeddings for ALL +// uncached tools (so they can be used in similarity calculations), but only CACHES +// the ones that fit within the cache limit, avoiding wasteful evictions. 
+// +// Returns a map of hashKey -> embedding for all successfully processed tools +func (p *SemanticToolFilteringPolicy) processToolEmbeddingsWithCache( + embeddingCache *EmbeddingCacheStore, + apiId string, + requests []toolEmbeddingRequest, +) map[string]toolEmbeddingResult { + results := make(map[string]toolEmbeddingResult) + + if len(requests) == 0 { + return results + } + + // Get cache limits + _, maxToolsPerAPI := embeddingCache.GetCacheLimits() + + // First pass: Check which tools are already cached + var uncachedRequests []toolEmbeddingRequest + + for _, req := range requests { + cachedEntry := embeddingCache.GetEntry(apiId, req.HashKey) + if cachedEntry != nil { + // Cache hit + results[req.HashKey] = toolEmbeddingResult{ + Name: req.Name, + HashKey: req.HashKey, + Embedding: cachedEntry.Embedding, + FromCache: true, + } + slog.Debug("SemanticToolFiltering: Cache hit for tool embedding", "toolName", req.Name) + } else { + uncachedRequests = append(uncachedRequests, req) + } + } + + slog.Debug("SemanticToolFiltering: Cache check complete", + "totalTools", len(requests), + "cachedTools", len(results), + "uncachedTools", len(uncachedRequests)) + + // Calculate available slots for caching new tools + apiCache := embeddingCache.GetAPICache(apiId) + currentCachedCount := 0 + if apiCache != nil { + currentCachedCount = len(apiCache) + } + availableSlots := maxToolsPerAPI - currentCachedCount + if availableSlots < 0 { + availableSlots = 0 + } + + slog.Debug("SemanticToolFiltering: Available cache slots", + "currentCached", currentCachedCount, + "maxToolsPerAPI", maxToolsPerAPI, + "availableSlots", availableSlots) + + // Generate embeddings for ALL uncached tools (for similarity calculation) + // but only cache the ones that fit + var toolEntriesToCache []ToolEntry + toolsCached := 0 + + for _, req := range uncachedRequests { + embedding, err := p.embeddingProvider.GetEmbedding(req.Description) + if err != nil { + slog.Warn("SemanticToolFiltering: Error 
generating tool embedding, skipping", + "error", err, "toolName", req.Name) + continue + } + + // Add to results so this tool can be used in similarity calculations + results[req.HashKey] = toolEmbeddingResult{ + Name: req.Name, + HashKey: req.HashKey, + Embedding: embedding, + FromCache: false, + } + + // Only cache if we have available slots + if toolsCached < availableSlots { + toolEntriesToCache = append(toolEntriesToCache, ToolEntry{ + HashKey: req.HashKey, + Name: req.Name, + Embedding: embedding, + }) + toolsCached++ + } else { + slog.Debug("SemanticToolFiltering: Tool processed but not cached (limit reached)", + "toolName", req.Name) + } + } + + slog.Debug("SemanticToolFiltering: Embedding generation complete", + "totalProcessed", len(results), + "newlyGenerated", len(uncachedRequests), + "willCache", toolsCached, + "notCached", len(uncachedRequests)-toolsCached) + + // Bulk add embeddings that fit in cache + if len(toolEntriesToCache) > 0 { + bulkResult := embeddingCache.BulkAddTools(apiId, toolEntriesToCache) + slog.Debug("SemanticToolFiltering: Bulk added tool embeddings to cache", + "added", len(bulkResult.Added), + "skipped", len(bulkResult.Skipped), + "alreadyCached", len(bulkResult.Cached)) + } + + return results +} + +// GetPolicy creates a new instance of the semantic tool filtering policy +func GetPolicy( + metadata policy.PolicyMetadata, + params map[string]interface{}, +) (policy.Policy, error) { + p := &SemanticToolFilteringPolicy{} + + // Parse and validate embedding provider configuration (from systemParameters) + if err := parseEmbeddingConfig(params, p); err != nil { + return nil, fmt.Errorf("invalid embedding config: %w", err) + } + + // Initialize embedding provider + embeddingProvider, err := createEmbeddingProvider(p.embeddingConfig) + if err != nil { + return nil, fmt.Errorf("failed to create embedding provider: %w", err) + } + p.embeddingProvider = embeddingProvider + + // Parse policy parameters (runtime parameters) + if err := parseParams(params, p);
err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + slog.Debug("SemanticToolFiltering: Policy initialized", + "embeddingProvider", p.embeddingConfig.EmbeddingProvider, + "selectionMode", p.selectionMode, + "topK", p.topK, + "threshold", p.threshold) + + return p, nil +} + +// GetPolicyV2 is the v1alpha2 factory entry point (loaded by v1alpha2 kernels). +func GetPolicyV2( + metadata policyv1alpha2.PolicyMetadata, + params map[string]interface{}, +) (policyv1alpha2.Policy, error) { + return GetPolicy(policy.PolicyMetadata{ + RouteName: metadata.RouteName, + APIId: metadata.APIId, + APIName: metadata.APIName, + APIVersion: metadata.APIVersion, + AttachedTo: policy.Level(metadata.AttachedTo), + }, params) +} + +// parseEmbeddingConfig parses and validates embedding provider configuration +func parseEmbeddingConfig(params map[string]interface{}, p *SemanticToolFilteringPolicy) error { + provider, ok := params["embeddingProvider"].(string) + if !ok || provider == "" { + return fmt.Errorf("'embeddingProvider' is required") + } + + embeddingEndpoint, ok := params["embeddingEndpoint"].(string) + if !ok || embeddingEndpoint == "" { + return fmt.Errorf("'embeddingEndpoint' is required") + } + + // embeddingModel is required for OPENAI and MISTRAL, but not for AZURE_OPENAI + embeddingModel, ok := params["embeddingModel"].(string) + if !ok || embeddingModel == "" { + providerUpper := strings.ToUpper(provider) + if providerUpper == "OPENAI" || providerUpper == "MISTRAL" { + return fmt.Errorf("'embeddingModel' is required for %s provider", provider) + } + // For AZURE_OPENAI, embeddingModel is optional (deployment name is in endpoint) + embeddingModel = "" + } + + apiKey, ok := params["apiKey"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("'apiKey' is required") + } + + // Set header name based on provider type + // Azure OpenAI uses "api-key", others use "Authorization" + authHeaderName := "Authorization" + if strings.ToUpper(provider) == "AZURE_OPENAI" {
+ authHeaderName = "api-key" + } + + p.embeddingConfig = embeddingproviders.EmbeddingProviderConfig{ + EmbeddingProvider: strings.ToUpper(provider), + EmbeddingEndpoint: embeddingEndpoint, + APIKey: apiKey, + AuthHeaderName: authHeaderName, + EmbeddingModel: embeddingModel, + TimeOut: strconv.Itoa(DefaultTimeoutMs), + } + + return nil +} + +// parseParams parses and validates runtime parameters from the params map +func parseParams(params map[string]interface{}, p *SemanticToolFilteringPolicy) error { + // Optional: selectionMode (default "By Rank") + selectionMode, ok := params["selectionMode"].(string) + if !ok || selectionMode == "" { + selectionMode = SelectionModeTopK + } + if selectionMode != SelectionModeTopK && selectionMode != SelectionModeThreshold { + return fmt.Errorf("'selectionMode' must be 'By Rank' or 'By Threshold'") + } + p.selectionMode = selectionMode + + // Optional: limit (default 5 as per policy-definition.yaml) + if limitRaw, ok := params["limit"]; ok { + limit, err := extractInt(limitRaw) + if err != nil { + return fmt.Errorf("'limit' must be a number: %w", err) + } + if limit < 0 || limit > 20 { + return fmt.Errorf("'limit' must be between 0 and 20") + } + p.topK = limit + } else { + p.topK = 5 // default from policy-definition.yaml + } + + // Optional: threshold (default 0.7 as per policy-definition.yaml) + if thresholdRaw, ok := params["threshold"]; ok { + threshold, err := extractFloat64(thresholdRaw) + if err != nil { + return fmt.Errorf("'threshold' must be a number: %w", err) + } + if threshold < 0.0 || threshold > 1.0 { + return fmt.Errorf("'threshold' must be between 0.0 and 1.0") + } + p.threshold = threshold + } else { + p.threshold = 0.7 // default from policy-definition.yaml + } + + // Optional: queryJSONPath (default "$.messages[-1].content" as per policy-definition.yaml) + if jsonPathRaw, ok := params["queryJSONPath"]; ok { + if jsonPath, ok := jsonPathRaw.(string); ok { + if jsonPath != "" { + p.queryJSONPath = jsonPath + }
else { + p.queryJSONPath = "$.messages[-1].content" // default from policy-definition.yaml + } + } else { + return fmt.Errorf("'queryJSONPath' must be a string") + } + } else { + p.queryJSONPath = "$.messages[-1].content" // default from policy-definition.yaml + } + + // Optional: toolsPath (default "$.tools" as per policy-definition.yaml) + if toolsPathRaw, ok := params["toolsJSONPath"]; ok { + if toolsPath, ok := toolsPathRaw.(string); ok { + if toolsPath != "" { + p.toolsJSONPath = toolsPath + } else { + p.toolsJSONPath = "$.tools" // default from policy-definition.yaml + } + } else { + return fmt.Errorf("'toolsJSONPath' must be a string") + } + } else { + p.toolsJSONPath = "$.tools" // default from policy-definition.yaml + } + + // Validate toolsJSONPath pattern - must be a simple dotted path with optional array indices + // Pattern: $.field1.field2[0].field3 or $.tools + // This restriction ensures compatibility with updateToolsInRequestBody which only supports + // simple dotted paths with optional single-level array indices + if err := validateSimpleJSONPath(p.toolsJSONPath); err != nil { + return fmt.Errorf("'toolsJSONPath' validation failed: %w", err) + } + + // Optional: userQueryIsJson (default true - JSON format) + if userQueryIsJsonRaw, ok := params["userQueryIsJson"]; ok { + userQueryIsJson, err := extractBool(userQueryIsJsonRaw) + if err != nil { + return fmt.Errorf("'userQueryIsJson' must be a boolean: %w", err) + } + p.userQueryIsJson = userQueryIsJson + } else { + p.userQueryIsJson = true // default to JSON format + } + + // Optional: toolsIsJson (default true - JSON format) + if toolsIsJsonRaw, ok := params["toolsIsJson"]; ok { + toolsIsJson, err := extractBool(toolsIsJsonRaw) + if err != nil { + return fmt.Errorf("'toolsIsJson' must be a boolean: %w", err) + } + p.toolsIsJson = toolsIsJson + } else { + p.toolsIsJson = true // default to JSON format + } + + return nil +} + +// extractFloat64 safely extracts a float64 from various types +func 
extractFloat64(value interface{}) (float64, error) { + switch v := value.(type) { + case float64: + return v, nil + case float32: + return float64(v), nil + case int: + return float64(v), nil + case int64: + return float64(v), nil + case string: + parsed, err := strconv.ParseFloat(v, 64) + if err != nil { + return 0, fmt.Errorf("cannot convert %q to float64: %w", v, err) + } + return parsed, nil + default: + return 0, fmt.Errorf("cannot convert %T to float64", value) + } +} + +// extractInt safely extracts an integer from various types +func extractInt(value interface{}) (int, error) { + switch v := value.(type) { + case int: + return v, nil + case int64: + return int(v), nil + case float64: + if v != float64(int(v)) { + return 0, fmt.Errorf("expected an integer but got %v", v) + } + return int(v), nil + case string: + parsed, err := strconv.Atoi(v) + if err != nil { + return 0, fmt.Errorf("cannot convert %q to int: %w", v, err) + } + return parsed, nil + default: + return 0, fmt.Errorf("cannot convert %T to int", value) + } +} + +// extractBool safely extracts a boolean from various types +func extractBool(value interface{}) (bool, error) { + switch v := value.(type) { + case bool: + return v, nil + case string: + lower := strings.ToLower(v) + if lower == "true" || lower == "1" || lower == "yes" { + return true, nil + } + if lower == "false" || lower == "0" || lower == "no" { + return false, nil + } + return false, fmt.Errorf("cannot convert %q to bool", v) + case int: + return v != 0, nil + case float64: + return v != 0, nil + default: + return false, fmt.Errorf("cannot convert %T to bool", value) + } +} + +// simpleJSONPathPattern validates that a JSONPath is a simple dotted path with optional array indices +// Supports patterns like: $.tools, $.data.items, $.results[0].tools, $.a.b[1].c[2].d +// Does NOT support: complex JSONPath expressions like $..[*], $..book[?(@.price<10)], etc. 
+var simpleJSONPathPattern = regexp.MustCompile(`^\$\.([a-zA-Z_][a-zA-Z0-9_]*(\[\d+\])?\.)*[a-zA-Z_][a-zA-Z0-9_]*(\[\d+\])?$`) + +// validateSimpleJSONPath validates that the given JSONPath is a simple dotted path +// that can be handled by updateToolsInRequestBody +func validateSimpleJSONPath(path string) error { + if path == "" { + return fmt.Errorf("path cannot be empty") + } + + // Must start with "$." + if !strings.HasPrefix(path, "$.") { + return fmt.Errorf("path must start with '$.' prefix, got: %s", path) + } + + // Validate against the simple pattern + if !simpleJSONPathPattern.MatchString(path) { + return fmt.Errorf("path contains unsupported JSONPath syntax; only simple dotted paths with optional array indices are supported (e.g., '$.tools', '$.data.items', '$.results[0].tools'); got: %s", path) + } + + return nil +} + +// createEmbeddingProvider creates a new embedding provider based on the config +func createEmbeddingProvider(config embeddingproviders.EmbeddingProviderConfig) (embeddingproviders.EmbeddingProvider, error) { + var provider embeddingproviders.EmbeddingProvider + + switch config.EmbeddingProvider { + case "OPENAI": + provider = &embeddingproviders.OpenAIEmbeddingProvider{} + case "MISTRAL": + provider = &embeddingproviders.MistralEmbeddingProvider{} + case "AZURE_OPENAI": + provider = &embeddingproviders.AzureOpenAIEmbeddingProvider{} + default: + return nil, fmt.Errorf("unsupported embedding provider: %s", config.EmbeddingProvider) + } + + if err := provider.Init(config); err != nil { + return nil, fmt.Errorf("failed to initialize embedding provider: %w", err) + } + + return provider, nil +} + +// Mode returns the processing mode for this policy +func (p *SemanticToolFilteringPolicy) Mode() policy.ProcessingMode { + return policy.ProcessingMode{ + RequestHeaderMode: policy.HeaderModeSkip, + RequestBodyMode: policy.BodyModeBuffer, // Need to read and modify request body + ResponseHeaderMode: policy.HeaderModeSkip, + ResponseBodyMode:
policy.BodyModeSkip, + } +} + +// extractUserQueryFromText extracts user query from text content using <userq>...</userq> tags +func extractUserQueryFromText(content string) (string, error) { + startTag := "<userq>" + endTag := "</userq>" + + startIdx := strings.Index(content, startTag) + if startIdx == -1 { + return "", fmt.Errorf("user query start tag not found") + } + + // Search for the end tag only after the start tag to avoid matching a stray end tag that appears earlier in the content + endIdx := strings.Index(content[startIdx+len(startTag):], endTag) + if endIdx == -1 { + return "", fmt.Errorf("user query end tag not found") + } + endIdx += startIdx + len(startTag) + + query := content[startIdx+len(startTag) : endIdx] + return strings.TrimSpace(query), nil +} + +// extractToolsFromText extracts tools from text content using <toolname> and <tooldescription> tags +func extractToolsFromText(content string) ([]TextTool, error) { + var tools []TextTool + + toolNameStartTag := "<toolname>" + toolNameEndTag := "</toolname>" + toolDescStartTag := "<tooldescription>" + toolDescEndTag := "</tooldescription>" + + // Find all tool definitions in the content + searchStart := 0 + for { + // Find tool name + nameStartIdx := strings.Index(content[searchStart:], toolNameStartTag) + if nameStartIdx == -1 { + break + } + nameStartIdx += searchStart + + nameEndIdx := strings.Index(content[nameStartIdx:], toolNameEndTag) + if nameEndIdx == -1 { + return nil, fmt.Errorf("tool name end tag not found for tool starting at position %d", nameStartIdx) + } + nameEndIdx += nameStartIdx + + toolName := strings.TrimSpace(content[nameStartIdx+len(toolNameStartTag) : nameEndIdx]) + + // Find tool description after the name + descSearchStart := nameEndIdx + len(toolNameEndTag) + descStartIdx := strings.Index(content[descSearchStart:], toolDescStartTag) + if descStartIdx == -1 { + return nil, fmt.Errorf("tool description start tag not found for tool '%s'", toolName) + } + descStartIdx += descSearchStart + + descEndIdx := strings.Index(content[descStartIdx:], toolDescEndTag) + if descEndIdx == -1 { + return nil, fmt.Errorf("tool description end tag
not found for tool '%s'", toolName) + } + descEndIdx += descStartIdx + + toolDesc := strings.TrimSpace(content[descStartIdx+len(toolDescStartTag) : descEndIdx]) + + tools = append(tools, TextTool{ + Name: toolName, + Description: toolDesc, + StartPos: nameStartIdx, + EndPos: descEndIdx + len(toolDescEndTag), + }) + + // Move search start past this tool + searchStart = descEndIdx + len(toolDescEndTag) + } + + return tools, nil +} + +// rebuildTextWithFilteredTools rebuilds the text content keeping only filtered tools +func rebuildTextWithFilteredTools(originalContent string, allTools []TextTool, filteredToolNames map[string]bool) string { + if len(allTools) == 0 { + return originalContent + } + + // Sort tools by start position in reverse order to process from end to start + // This ensures position calculations remain valid as we remove content + sortedTools := make([]TextTool, len(allTools)) + copy(sortedTools, allTools) + sort.Slice(sortedTools, func(i, j int) bool { + return sortedTools[i].StartPos > sortedTools[j].StartPos + }) + + result := originalContent + + // Remove tools that are not in the filtered list + for _, tool := range sortedTools { + if !filteredToolNames[tool.Name] { + // Remove this tool from the content + result = result[:tool.StartPos] + result[tool.EndPos:] + } + } + + // Clean up any extra blank lines left after removal + result = cleanupWhitespace(result) + + return result +} + +// stripAllTags removes all text-format tags (userq, toolname, tooldescription) from the content. +// Called after filtering so the downstream payload is clean plain text. 
+func stripAllTags(content string) string { + content = strings.ReplaceAll(content, "<userq>", "") + content = strings.ReplaceAll(content, "</userq>", "") + content = strings.ReplaceAll(content, "<toolname>", "") + content = strings.ReplaceAll(content, "</toolname>", "") + content = strings.ReplaceAll(content, "<tooldescription>", "") + content = strings.ReplaceAll(content, "</tooldescription>", "") + return cleanupWhitespace(content) +} + +// cleanupWhitespace removes excessive blank lines while preserving original spacing and indentation. +// Only collapses multiple consecutive blank lines (3+ newlines) to a double newline. +// Does NOT modify spaces or trim content to preserve user prompts exactly. +func cleanupWhitespace(content string) string { + // Replace multiple consecutive newlines (3+) with double newline only + for strings.Contains(content, "\n\n\n") { + content = strings.ReplaceAll(content, "\n\n\n", "\n\n") + } + return content +} + +// OnResponse is a no-op for this policy (only modifies requests) +func (p *SemanticToolFilteringPolicy) OnResponse(ctx *policy.ResponseContext, params map[string]interface{}) policy.ResponseAction { + return policy.UpstreamResponseModifications{} +} + +// extractToolDescription extracts description text from a tool definition +func extractToolDescription(tool map[string]interface{}) string { + // Try common fields for tool description + fields := []string{"description", "desc", "summary", "info"} + + for _, field := range fields { + if desc, ok := tool[field].(string); ok && desc != "" { + return desc + } + } + + // If no description field, try to use name + function description + name, _ := tool["name"].(string) + + // Check for function/parameters structure (OpenAI format) + if function, ok := tool["function"].(map[string]interface{}); ok { + if desc, ok := function["description"].(string); ok && desc != "" { + if name != "" { + return fmt.Sprintf("%s: %s", name, desc) + } + return desc + } + } + + // Fallback to just name if available + if name != "" { + return name + } + + return "" +}
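+ +// exampleToolDescriptions is an illustrative sketch (hypothetical helper, not called +// by the policy) tracing the precedence extractToolDescription applies: a top-level +// "description" field wins; otherwise an OpenAI-style function description is prefixed +// with the tool name; otherwise the bare name is used. +func exampleToolDescriptions() []string { + flat := map[string]interface{}{"name": "get_weather", "description": "Current weather"} + openAI := map[string]interface{}{"name": "get_weather", "function": map[string]interface{}{"description": "Current weather"}} + nameOnly := map[string]interface{}{"name": "search_web"} + // Yields "Current weather", "get_weather: Current weather", "search_web" in order. + return []string{extractToolDescription(flat), extractToolDescription(openAI), extractToolDescription(nameOnly)} +}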
+ +// cosineSimilarity calculates cosine similarity between two embeddings +func cosineSimilarity(a, b []float32) (float64, error) { + if len(a) == 0 || len(b) == 0 { + return 0, fmt.Errorf("embedding vectors cannot be empty") + } + + if len(a) != len(b) { + return 0, fmt.Errorf("embedding dimensions do not match: %d vs %d", len(a), len(b)) + } + + var dot, normA, normB float64 + for i := range a { + dot += float64(a[i] * b[i]) + normA += float64(a[i] * a[i]) + normB += float64(b[i] * b[i]) + } + + if normA == 0 || normB == 0 { + return 0, fmt.Errorf("embedding vector norm is zero") + } + + return dot / (math.Sqrt(normA) * math.Sqrt(normB)), nil +} + +// filterTools filters tools based on selection mode and criteria +func (p *SemanticToolFilteringPolicy) filterTools(toolsWithScores []ToolWithScore) []map[string]interface{} { + // Sort by score in descending order + sort.Slice(toolsWithScores, func(i, j int) bool { + return toolsWithScores[i].Score > toolsWithScores[j].Score + }) + + var filtered []map[string]interface{} + + switch p.selectionMode { + case SelectionModeTopK: + // Select top K tools + limit := p.topK + if limit > len(toolsWithScores) { + limit = len(toolsWithScores) + } + for i := 0; i < limit; i++ { + filtered = append(filtered, toolsWithScores[i].Tool) + } + + case SelectionModeThreshold: + // Select all tools above threshold + for _, item := range toolsWithScores { + if item.Score >= p.threshold { + filtered = append(filtered, item.Tool) + } + } + } + + return filtered +} + +// updateToolsInRequestBody updates the tools array in the request body +func updateToolsInRequestBody(requestBody *map[string]interface{}, toolsPath string, tools []map[string]interface{}) error { + // Remove leading "$." 
if present + path := strings.TrimPrefix(toolsPath, "$.") + parts := strings.Split(path, ".") + + if len(parts) == 0 { + return fmt.Errorf("invalid toolsPath: %s", toolsPath) + } + + // Handle array index in path, e.g., "tools[0]" + curr := *requestBody + for idx, part := range parts { + // Check if part contains array index, e.g., "tools[0]" + if openIdx := strings.Index(part, "["); openIdx != -1 && strings.HasSuffix(part, "]") { + field := part[:openIdx] + indexStr := part[openIdx+1 : len(part)-1] + index, err := strconv.Atoi(indexStr) + if err != nil { + return fmt.Errorf("invalid array index in path: %s", part) + } + if index < 0 { + return fmt.Errorf("negative array index in path: %s", part) + } + + // If this is the last part, set the value at the array index + if idx == len(parts)-1 { + // Ensure the array exists + arr, ok := curr[field].([]interface{}) + if !ok { + // Create array if not present + arr = make([]interface{}, index+1) + } else if len(arr) <= index { + // Extend array if needed + newArr := make([]interface{}, index+1) + copy(newArr, arr) + arr = newArr + } + arr[index] = tools + curr[field] = arr + return nil + } + + // Not last part, descend into the array element + arr, ok := curr[field].([]interface{}) + if !ok { + // Create array if not present + arr = make([]interface{}, index+1) + curr[field] = arr + } else if len(arr) <= index { + // Extend array if needed + newArr := make([]interface{}, index+1) + copy(newArr, arr) + arr = newArr + curr[field] = arr + } + // If element is nil, create map + if arr[index] == nil { + arr[index] = make(map[string]interface{}) + } + nextMap, ok := arr[index].(map[string]interface{}) + if !ok { + return fmt.Errorf("expected map at array index %d in field %s", index, field) + } + curr = nextMap + continue + } + + // If this is the last part, set the value + if idx == len(parts)-1 { + curr[part] = tools + return nil + } + + // If the next level doesn't exist, create it as a map + next, ok := curr[part] + if !ok 
{ + newMap := make(map[string]interface{}) + curr[part] = newMap + curr = newMap + continue + } + + // If the next level is a map, descend into it + nextMap, ok := next.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected map at path %s but found %T", part, next) + } + curr = nextMap + } + + return nil +}