diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go index d1dd9c3a7..a42a271df 100644 --- a/internal/apischema/openai/openai.go +++ b/internal/apischema/openai/openai.go @@ -1071,8 +1071,9 @@ type StreamOptions struct { type ToolType string const ( - ToolTypeFunction ToolType = "function" - ToolTypeImageGeneration ToolType = "image_generation" + ToolTypeFunction ToolType = "function" + ToolTypeImageGeneration ToolType = "image_generation" + ToolTypeEnterpriseWebSearch ToolType = "enterprise_search" ) type Tool struct { @@ -1314,6 +1315,9 @@ type ChatCompletionResponseChoiceMessage struct { // List of ratings for the safety of a response candidate. There is at most one rating per category. // https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1/GenerateContentResponse#SafetyRating SafetyRatings []*genai.SafetyRating `json:"safety_ratings,omitempty"` + + // GroundingMetadata specifies sources used to ground generated content. + GroundingMetadata *genai.GroundingMetadata `json:"grounding_metadata,omitempty"` } // URLCitation contains citation information for web search results. diff --git a/internal/apischema/openai/openai_test.go b/internal/apischema/openai/openai_test.go index 9dfd0261c..284cf0c1a 100644 --- a/internal/apischema/openai/openai_test.go +++ b/internal/apischema/openai/openai_test.go @@ -1193,6 +1193,133 @@ func TestChatCompletionRequest(t *testing.T) { }, }, }, + { + name: "enterprise search tool", + jsonStr: `{ + "model": "gemini-1.5-pro", + "messages": [ + { + "role": "user", + "content": "Hello with enterprise search!" + } + ], + "tools": [ + { + "type": "enterprise_search" + } + ] + }`, + expected: &ChatCompletionRequest{ + Model: "gemini-1.5-pro", + Messages: []ChatCompletionMessageParamUnion{ + { + OfUser: &ChatCompletionUserMessageParam{ + Role: ChatMessageRoleUser, + Content: StringOrUserRoleContentUnion{Value: "Hello with enterprise search!"}, + }, + }, + }, + Tools: []Tool{ + { + Type: ToolTypeEnterpriseWebSearch, + }, + }, + }, + }, + { + name: "mixed function and enterprise search tools", + jsonStr: `{ + "model": "gemini-1.5-pro", + "messages": [ + { + "role": "user", + "content": "Mixed tools test" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather" + } + }, + { + "type": "enterprise_search" + } + ] + }`, + expected: &ChatCompletionRequest{ + Model: "gemini-1.5-pro", + Messages: []ChatCompletionMessageParamUnion{ + { + OfUser: &ChatCompletionUserMessageParam{ + Role: ChatMessageRoleUser, + Content: StringOrUserRoleContentUnion{Value: "Mixed tools test"}, + }, + }, + }, + Tools: []Tool{ + { + Type: ToolTypeFunction, + Function: &FunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + }, + }, + { + Type: ToolTypeEnterpriseWebSearch, + }, + }, + }, + }, + { + name: "enterprise search with vendor fields", + jsonStr: `{ + "model": "gemini-1.5-pro", + "messages": [ + { + "role": "user", + "content": "Combined enterprise search and safety settings" + } + ], + "tools": [ + { + "type": "enterprise_search" + } + ], + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_ONLY_HIGH" + } + ] + }`, + expected: &ChatCompletionRequest{ + Model: "gemini-1.5-pro", + Messages: []ChatCompletionMessageParamUnion{ + { + OfUser: &ChatCompletionUserMessageParam{ + Role: ChatMessageRoleUser, + Content: StringOrUserRoleContentUnion{Value: "Combined enterprise search and safety settings"}, + }, + }, + }, + Tools: []Tool{ + { + Type: ToolTypeEnterpriseWebSearch, + }, + }, + GCPVertexAIVendorFields: &GCPVertexAIVendorFields{ + SafetySettings: []*genai.SafetySetting{ + { + Category: genai.HarmCategoryHarassment, + Threshold: genai.HarmBlockThresholdBlockOnlyHigh, + }, + }, + }, + }, + }, } for _, tc := range testCases { diff --git a/internal/translator/gemini_helper.go b/internal/translator/gemini_helper.go index c79d157b9..840673624 100644 --- a/internal/translator/gemini_helper.go +++ b/internal/translator/gemini_helper.go @@ -358,6 +358,9 @@ func openAIToolsToGeminiTools(openaiTools []openai.Tool, parametersJSONSchemaAva if len(openaiTools) == 0 { return nil, nil } + + var genaiTools []genai.Tool + var functionDecls []*genai.FunctionDeclaration for _, tool := range openaiTools { @@ -389,14 +392,27 @@ func openAIToolsToGeminiTools(openaiTools []openai.Tool, parametersJSONSchemaAva } case openai.ToolTypeImageGeneration: return nil, fmt.Errorf("tool-type image generation not supported yet when translating OpenAI req to Gemini") + case openai.ToolTypeEnterpriseWebSearch: + genaiTools = append(genaiTools, genai.Tool{ + EnterpriseWebSearch: &genai.EnterpriseWebSearch{}, + }) default: return nil, fmt.Errorf("unsupported tool type: %s", tool.Type) } } - if len(functionDecls) == 0 { + // Only return nil if there are no tools at all (neither function declarations nor other tools) + if len(functionDecls) == 0 && len(genaiTools) == 0 { return nil, nil } - return []genai.Tool{{FunctionDeclarations: functionDecls}}, nil + + // Only append function declarations if there are any + if len(functionDecls) > 0 { + genaiTools = append(genaiTools, genai.Tool{ + FunctionDeclarations: functionDecls, + }) + } + + return genaiTools, nil } // openAIToolChoiceToGeminiToolConfig converts OpenAI tool_choice to Gemini ToolConfig. @@ -682,6 +698,14 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode choice.Message.SafetyRatings = candidate.SafetyRatings } + if candidate.GroundingMetadata != nil { + if choice.Message.Role == "" { + choice.Message.Role = openai.ChatMessageRoleAssistant + } + + choice.Message.GroundingMetadata = candidate.GroundingMetadata + } + // Handle logprobs if available. if candidate.LogprobsResult != nil { choice.Logprobs = geminiLogprobsToOpenAILogprobs(*candidate.LogprobsResult) diff --git a/internal/translator/gemini_helper_test.go b/internal/translator/gemini_helper_test.go index 58e410eff..510785960 100644 --- a/internal/translator/gemini_helper_test.go +++ b/internal/translator/gemini_helper_test.go @@ -1420,6 +1420,58 @@ func TestOpenAIToolsToGeminiTools(t *testing.T) { parametersJSONSchemaAvailable: false, expected: nil, }, + { + name: "enterprise search tool only", + openaiTools: []openai.Tool{ + { + Type: openai.ToolTypeEnterpriseWebSearch, + }, + }, + parametersJSONSchemaAvailable: false, + expected: []genai.Tool{ + { + EnterpriseWebSearch: &genai.EnterpriseWebSearch{}, + }, + }, + }, + { + name: "mixed function and enterprise search tools", + openaiTools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + Parameters: funcParams, + }, + }, + { + Type: openai.ToolTypeEnterpriseWebSearch, + }, + }, + parametersJSONSchemaAvailable: false, + expected: []genai.Tool{ + { + EnterpriseWebSearch: &genai.EnterpriseWebSearch{}, + }, + { + FunctionDeclarations: []*genai.FunctionDeclaration{ + { + Name: "get_weather", + Description: "Get current weather", + Parameters: &genai.Schema{ + Type: "object", + Properties: map[string]*genai.Schema{ + "a": {Type: "integer"}, + "b": {Type: "integer"}, + }, + Required: []string{"a", "b"}, + }, + }, + }, + }, + }, + }, } for _, tc := range tests { diff --git a/internal/translator/openai_gcpvertexai_test.go b/internal/translator/openai_gcpvertexai_test.go index a60b1bcf9..488cc39e2 100644 --- a/internal/translator/openai_gcpvertexai_test.go +++ b/internal/translator/openai_gcpvertexai_test.go @@ -335,6 +335,28 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) } }`) + wantBdyWithEnterpriseWebSearch := []byte(`{ + "contents": [ + { + "parts": [ + { + "text": "Test with web grounding for enterprise" + } + ], + "role": "user" + } + ], + "tools": [ + { + "enterpriseWebSearch": {} + } + ], + "generation_config": { + "maxOutputTokens": 1024, + "temperature": 0.7 + } +}`) + tests := []struct { name string modelNameOverride internalapi.ModelNameOverride @@ -739,6 +761,34 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) }, wantBody: wantBdyWithGuidedRegex, }, + { + name: "Request with gcp web grounding for enterprise", + input: openai.ChatCompletionRequest{ + Model: "gemini-1.5-pro", + Temperature: ptr.To(0.7), + MaxTokens: ptr.To(int64(1024)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{Value: "Test with web grounding for enterprise"}, + }, + }, + }, + Tools: []openai.Tool{ + { + Type: "enterprise_search", + }, + }, + }, + onRetry: false, + wantError: false, + wantHeaderMut: []internalapi.Header{ + {":path", "publishers/google/models/gemini-1.5-pro:generateContent"}, + {"content-length", "190"}, + }, + wantBody: wantBdyWithEnterpriseWebSearch, + }, } for _, tc := range tests {