Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1071,8 +1071,9 @@ type StreamOptions struct {
type ToolType string

const (
ToolTypeFunction ToolType = "function"
ToolTypeImageGeneration ToolType = "image_generation"
ToolTypeFunction ToolType = "function"
ToolTypeImageGeneration ToolType = "image_generation"
ToolTypeEnterpriseWebSearch ToolType = "enterprise_search"
)

type Tool struct {
Expand Down Expand Up @@ -1314,6 +1315,9 @@ type ChatCompletionResponseChoiceMessage struct {
// List of ratings for the safety of a response candidate. There is at most one rating per category.
// https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1/GenerateContentResponse#SafetyRating
SafetyRatings []*genai.SafetyRating `json:"safety_ratings,omitempty"`

// GroundingMetadata specifies sources used to ground generated content.
GroundingMetadata *genai.GroundingMetadata `json:"grounding_metadata,omitempty"`
}

// URLCitation contains citation information for web search results.
Expand Down
127 changes: 127 additions & 0 deletions internal/apischema/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,133 @@ func TestChatCompletionRequest(t *testing.T) {
},
},
},
{
name: "enterprise search tool",
jsonStr: `{
"model": "gemini-1.5-pro",
"messages": [
{
"role": "user",
"content": "Hello with enterprise search!"
}
],
"tools": [
{
"type": "enterprise_search"
}
]
}`,
expected: &ChatCompletionRequest{
Model: "gemini-1.5-pro",
Messages: []ChatCompletionMessageParamUnion{
{
OfUser: &ChatCompletionUserMessageParam{
Role: ChatMessageRoleUser,
Content: StringOrUserRoleContentUnion{Value: "Hello with enterprise search!"},
},
},
},
Tools: []Tool{
{
Type: ToolTypeEnterpriseWebSearch,
},
},
},
},
{
name: "mixed function and enterprise search tools",
jsonStr: `{
"model": "gemini-1.5-pro",
"messages": [
{
"role": "user",
"content": "Mixed tools test"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather"
}
},
{
"type": "enterprise_search"
}
]
}`,
expected: &ChatCompletionRequest{
Model: "gemini-1.5-pro",
Messages: []ChatCompletionMessageParamUnion{
{
OfUser: &ChatCompletionUserMessageParam{
Role: ChatMessageRoleUser,
Content: StringOrUserRoleContentUnion{Value: "Mixed tools test"},
},
},
},
Tools: []Tool{
{
Type: ToolTypeFunction,
Function: &FunctionDefinition{
Name: "get_weather",
Description: "Get current weather",
},
},
{
Type: ToolTypeEnterpriseWebSearch,
},
},
},
},
{
name: "enterprise search with vendor fields",
jsonStr: `{
"model": "gemini-1.5-pro",
"messages": [
{
"role": "user",
"content": "Combined enterprise search and safety settings"
}
],
"tools": [
{
"type": "enterprise_search"
}
],
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
}
]
}`,
expected: &ChatCompletionRequest{
Model: "gemini-1.5-pro",
Messages: []ChatCompletionMessageParamUnion{
{
OfUser: &ChatCompletionUserMessageParam{
Role: ChatMessageRoleUser,
Content: StringOrUserRoleContentUnion{Value: "Combined enterprise search and safety settings"},
},
},
},
Tools: []Tool{
{
Type: ToolTypeEnterpriseWebSearch,
},
},
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
SafetySettings: []*genai.SafetySetting{
{
Category: genai.HarmCategoryHarassment,
Threshold: genai.HarmBlockThresholdBlockOnlyHigh,
},
},
},
},
},
}

for _, tc := range testCases {
Expand Down
28 changes: 26 additions & 2 deletions internal/translator/gemini_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ func openAIToolsToGeminiTools(openaiTools []openai.Tool, parametersJSONSchemaAva
if len(openaiTools) == 0 {
return nil, nil
}

var genaiTools []genai.Tool

var functionDecls []*genai.FunctionDeclaration

for _, tool := range openaiTools {
Expand Down Expand Up @@ -389,14 +392,27 @@ func openAIToolsToGeminiTools(openaiTools []openai.Tool, parametersJSONSchemaAva
}
case openai.ToolTypeImageGeneration:
return nil, fmt.Errorf("tool-type image generation not supported yet when translating OpenAI req to Gemini")
case openai.ToolTypeEnterpriseWebSearch:
genaiTools = append(genaiTools, genai.Tool{
EnterpriseWebSearch: &genai.EnterpriseWebSearch{},
})
default:
return nil, fmt.Errorf("unsupported tool type: %s", tool.Type)
}
}
if len(functionDecls) == 0 {
// Only return nil if there are no tools at all (neither function declarations nor other tools)
if len(functionDecls) == 0 && len(genaiTools) == 0 {
return nil, nil
}
return []genai.Tool{{FunctionDeclarations: functionDecls}}, nil

// Only append function declarations if there are any
if len(functionDecls) > 0 {
genaiTools = append(genaiTools, genai.Tool{
FunctionDeclarations: functionDecls,
})
}

return genaiTools, nil
}

// openAIToolChoiceToGeminiToolConfig converts OpenAI tool_choice to Gemini ToolConfig.
Expand Down Expand Up @@ -682,6 +698,14 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
choice.Message.SafetyRatings = candidate.SafetyRatings
}

if candidate.GroundingMetadata != nil {
if choice.Message.Role == "" {
choice.Message.Role = openai.ChatMessageRoleAssistant
}

choice.Message.GroundingMetadata = candidate.GroundingMetadata
}

// Handle logprobs if available.
if candidate.LogprobsResult != nil {
choice.Logprobs = geminiLogprobsToOpenAILogprobs(*candidate.LogprobsResult)
Expand Down
52 changes: 52 additions & 0 deletions internal/translator/gemini_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,58 @@ func TestOpenAIToolsToGeminiTools(t *testing.T) {
parametersJSONSchemaAvailable: false,
expected: nil,
},
{
name: "enterprise search tool only",
openaiTools: []openai.Tool{
{
Type: openai.ToolTypeEnterpriseWebSearch,
},
},
parametersJSONSchemaAvailable: false,
expected: []genai.Tool{
{
EnterpriseWebSearch: &genai.EnterpriseWebSearch{},
},
},
},
{
name: "mixed function and enterprise search tools",
openaiTools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
Function: &openai.FunctionDefinition{
Name: "get_weather",
Description: "Get current weather",
Parameters: funcParams,
},
},
{
Type: openai.ToolTypeEnterpriseWebSearch,
},
},
parametersJSONSchemaAvailable: false,
expected: []genai.Tool{
{
EnterpriseWebSearch: &genai.EnterpriseWebSearch{},
},
{
FunctionDeclarations: []*genai.FunctionDeclaration{
{
Name: "get_weather",
Description: "Get current weather",
Parameters: &genai.Schema{
Type: "object",
Properties: map[string]*genai.Schema{
"a": {Type: "integer"},
"b": {Type: "integer"},
},
Required: []string{"a", "b"},
},
},
},
},
},
},
}

for _, tc := range tests {
Expand Down
50 changes: 50 additions & 0 deletions internal/translator/openai_gcpvertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,28 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T)
}
}`)

wantBdyWithEnterpriseWebSearch := []byte(`{
"contents": [
{
"parts": [
{
"text": "Test with web grounding for enterprise"
}
],
"role": "user"
}
],
"tools": [
{
"enterpriseWebSearch": {}
}
],
"generation_config": {
"maxOutputTokens": 1024,
"temperature": 0.7
}
}`)

tests := []struct {
name string
modelNameOverride internalapi.ModelNameOverride
Expand Down Expand Up @@ -739,6 +761,34 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T)
},
wantBody: wantBdyWithGuidedRegex,
},
{
name: "Request with gcp web grounding for enterprise",
input: openai.ChatCompletionRequest{
Model: "gemini-1.5-pro",
Temperature: ptr.To(0.7),
MaxTokens: ptr.To(int64(1024)),
Messages: []openai.ChatCompletionMessageParamUnion{
{
OfUser: &openai.ChatCompletionUserMessageParam{
Role: openai.ChatMessageRoleUser,
Content: openai.StringOrUserRoleContentUnion{Value: "Test with web grounding for enterprise"},
},
},
},
Tools: []openai.Tool{
{
Type: "enterprise_search",
},
},
},
onRetry: false,
wantError: false,
wantHeaderMut: []internalapi.Header{
{":path", "publishers/google/models/gemini-1.5-pro:generateContent"},
{"content-length", "190"},
},
wantBody: wantBdyWithEnterpriseWebSearch,
},
}

for _, tc := range tests {
Expand Down
Loading