Skip to content

Commit 363b82a

Browse files
authored
vertexai[patch]: support code_execution (#959)
1 parent 2f862d5 commit 363b82a

File tree

3 files changed

+143
-2
lines changed

3 files changed

+143
-2
lines changed

libs/vertexai/langchain_google_vertexai/chat_models.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@
114114
from google.cloud.aiplatform_v1beta1.types import (
115115
Blob,
116116
Candidate,
117+
CodeExecutionResult,
117118
Part,
118119
HarmCategory,
119120
Content,
121+
ExecutableCode,
120122
FileData,
121123
FunctionCall,
122124
FunctionResponse,
@@ -242,6 +244,29 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Optional[Part]:
242244
return Part(text=part["text"])
243245
else:
244246
return None
247+
if part["type"] == "executable_code":
248+
if "executable_code" not in part or "language" not in part:
249+
raise ValueError(
250+
"Executable code part must have 'code' and 'language' keys, got "
251+
f"{part}"
252+
)
253+
return Part(
254+
executable_code=ExecutableCode(
255+
language=part["language"], code=part["executable_code"]
256+
)
257+
)
258+
if part["type"] == "code_execution_result":
259+
if "code_execution_result" not in part or "outcome" not in part:
260+
raise ValueError(
261+
"Code execution result part must have 'code_execution_result' and "
262+
f"'outcome' keys, got {part}"
263+
)
264+
return Part(
265+
code_execution_result=CodeExecutionResult(
266+
output=part["code_execution_result"], outcome=part["outcome"]
267+
)
268+
)
269+
245270
if is_data_content_block(part):
246271
# LangChain standard format
247272
if part["type"] == "image" and part["source_type"] == "url":
@@ -542,7 +567,7 @@ def _parse_response_candidate(
542567
def _parse_response_candidate(
543568
response_candidate: "Candidate", streaming: bool = False
544569
) -> AIMessage:
545-
content: Union[None, str, List[str]] = None
570+
content: Union[None, str, List[Union[str, dict[str, Any]]]] = None
546571
additional_kwargs = {}
547572
tool_calls = []
548573
invalid_tool_calls = []
@@ -610,6 +635,44 @@ def _parse_response_candidate(
610635
error=str(e),
611636
)
612637
)
638+
if hasattr(part, "executable_code") and part.executable_code is not None:
639+
if part.executable_code.code and part.executable_code.language:
640+
code_message = {
641+
"type": "executable_code",
642+
"executable_code": part.executable_code.code,
643+
"language": part.executable_code.language,
644+
}
645+
if not content:
646+
content = [code_message]
647+
elif isinstance(content, str):
648+
content = [content, code_message]
649+
elif isinstance(content, list):
650+
content.append(code_message)
651+
else:
652+
raise Exception("Unexpected content type")
653+
654+
if (
655+
hasattr(part, "code_execution_result")
656+
and part.code_execution_result is not None
657+
):
658+
if part.code_execution_result.output and part.code_execution_result.outcome:
659+
execution_result = {
660+
"type": "code_execution_result",
661+
# Name output -> code_execution_result for consistency with
662+
# langchain-google-genai
663+
"code_execution_result": part.code_execution_result.output,
664+
"outcome": part.code_execution_result.outcome,
665+
}
666+
667+
if not content:
668+
content = [execution_result]
669+
elif isinstance(content, str):
670+
content = [content, execution_result]
671+
elif isinstance(content, list):
672+
content.append(execution_result)
673+
else:
674+
raise Exception("Unexpected content type")
675+
613676
if content is None:
614677
content = ""
615678

@@ -896,16 +959,30 @@ class GetPopulation(BaseModel):
896959
897960
See ``ChatVertexAI.bind_tools()`` method for more.
898961
899-
Use Search with Gemini 2:
962+
Built-in search:
900963
.. code-block:: python
901964
902965
from google.cloud.aiplatform_v1beta1.types import Tool as VertexTool
966+
from langchain_google_vertexai import ChatVertexAI
967+
903968
llm = ChatVertexAI(model="gemini-2.0-flash-exp")
904969
resp = llm.invoke(
905970
"When is the next total solar eclipse in US?",
906971
tools=[VertexTool(google_search={})],
907972
)
908973
974+
Built-in code execution:
975+
.. code-block:: python
976+
977+
from google.cloud.aiplatform_v1beta1.types import Tool as VertexTool
978+
from langchain_google_vertexai import ChatVertexAI
979+
980+
llm = ChatVertexAI(model="gemini-2.0-flash-exp")
981+
resp = llm.invoke(
982+
"What is 3^3?",
983+
tools=[VertexTool(code_execution={})],
984+
)
985+
909986
Structured output:
910987
.. code-block:: python
911988

libs/vertexai/langchain_google_vertexai/functions_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from langchain_core.utils.json_schema import dereference_refs
3434
from pydantic import BaseModel
35+
from typing_extensions import NotRequired
3536

3637
logger = logging.getLogger(__name__)
3738

@@ -49,13 +50,15 @@
4950
]
5051
_GoogleSearchLike = Union[gapic.Tool.GoogleSearch, Dict[str, Any]]
5152
_RetrievalLike = Union[gapic.Retrieval, Dict[str, Any]]
53+
_CodeExecutionLike = Union[gapic.Tool.CodeExecution, Dict[str, Any]]
5254

5355

5456
class _ToolDictLike(TypedDict):
5557
function_declarations: Optional[List[_FunctionDeclarationLike]]
5658
google_search_retrieval: Optional[_GoogleSearchRetrievalLike]
5759
google_search: Optional[_GoogleSearchLike]
5860
retrieval: Optional[_RetrievalLike]
61+
code_execution: NotRequired[_CodeExecutionLike]
5962

6063

6164
_ToolType = Union[gapic.Tool, vertexai.Tool, _ToolDictLike, _FunctionDeclarationLike]
@@ -306,6 +309,8 @@ def _format_to_gapic_tool(tools: _ToolsType) -> gapic.Tool:
306309
gapic_tool.function_declarations.extend(rt.function_declarations)
307310
if "google_search" in rt:
308311
gapic_tool.google_search = rt.google_search
312+
if "code_execution" in rt:
313+
gapic_tool.code_execution = rt.code_execution
309314
elif isinstance(tool, dict):
310315
# not _ToolDictLike
311316
if not any(
@@ -315,6 +320,7 @@ def _format_to_gapic_tool(tools: _ToolsType) -> gapic.Tool:
315320
"google_search_retrieval",
316321
"google_search",
317322
"retrieval",
323+
"code_execution",
318324
]
319325
):
320326
fd = _format_to_gapic_function_declaration(tool)
@@ -345,6 +351,10 @@ def _format_to_gapic_tool(tools: _ToolsType) -> gapic.Tool:
345351
)
346352
if "retrieval" in tool:
347353
gapic_tool.retrieval = gapic.Retrieval(tool["retrieval"])
354+
if "code_execution" in tool:
355+
gapic_tool.code_execution = gapic.Tool.CodeExecution(
356+
tool["code_execution"]
357+
)
348358
else:
349359
fd = _format_to_gapic_function_declaration(tool)
350360
gapic_tool.function_declarations.append(fd)

libs/vertexai/tests/integration_tests/test_chat_models.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,3 +1390,57 @@ class People(BaseModel):
13901390
response = llm_with_tools.invoke("Chester, no hair color provided.")
13911391
assert isinstance(response, AIMessage)
13921392
assert response.tool_calls[0]["name"] == "People"
1393+
1394+
1395+
def test_search_builtin() -> None:
    """Smoke-test the built-in Google Search tool end to end.

    Exercises three paths: a plain invoke, a streamed response whose
    chunks are merged, and re-sending the merged AI message as history.
    """
    llm = ChatVertexAI(model=_DEFAULT_MODEL_NAME).bind_tools([{"google_search": {}}])
    question = {
        "role": "user",
        "content": "What is today's news?",
    }

    response = llm.invoke([question])
    assert "grounding_metadata" in response.response_metadata

    # Streaming: merge chunks and check grounding metadata survives the merge.
    aggregated: Optional[BaseMessageChunk] = None
    for piece in llm.stream([question]):
        assert isinstance(piece, AIMessageChunk)
        aggregated = piece if aggregated is None else aggregated + piece
    assert isinstance(aggregated, AIMessageChunk)
    assert "grounding_metadata" in aggregated.response_metadata

    # History round-trip: the merged AI chunk must be accepted as chat input.
    follow_up = {
        "role": "user",
        "content": "Tell me more about that last story.",
    }
    _ = llm.invoke([question, aggregated, follow_up])
1418+
1419+
1420+
def test_code_execution_builtin() -> None:
    """Smoke-test the built-in code-execution tool end to end.

    Verifies that both invoke and stream responses carry the two
    code-execution content-block types, and that the streamed message
    can be replayed as chat history.
    """
    llm = ChatVertexAI(model=_DEFAULT_MODEL_NAME).bind_tools([{"code_execution": {}}])
    question = {
        "role": "user",
        "content": "What is 3^3?",
    }
    # Both block types must appear: the generated code and its result.
    expected_block_types = {"executable_code", "code_execution_result"}

    response = llm.invoke([question])
    dict_blocks = [block for block in response.content if isinstance(block, dict)]
    assert {block.get("type") for block in dict_blocks} == expected_block_types

    # Streaming: merge chunks, then check the same block types are present.
    aggregated: Optional[BaseMessageChunk] = None
    for piece in llm.stream([question]):
        assert isinstance(piece, AIMessageChunk)
        aggregated = piece if aggregated is None else aggregated + piece
    assert isinstance(aggregated, AIMessageChunk)
    dict_blocks = [block for block in aggregated.content if isinstance(block, dict)]
    assert {block.get("type") for block in dict_blocks} == expected_block_types

    # History round-trip: the merged AI chunk must be accepted as chat input.
    follow_up = {
        "role": "user",
        "content": "Can you add some comments to the code?",
    }
    _ = llm.invoke([question, aggregated, follow_up])

0 commit comments

Comments
 (0)