Skip to content

Commit ad4fadb

Browse files
Fix resolving of references in ChatVertexAI.with_structured_output (#843)
1 parent 88a49ea commit ad4fadb

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

libs/vertexai/langchain_google_vertexai/chat_models.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,12 +1150,12 @@ class Joke(BaseModel):
11501150

11511151
logprobs: Union[bool, int] = False
11521152
"""Whether to return logprobs as part of AIMessage.response_metadata.
1153-
1154-
If False, don't return logprobs. If True, return logprobs for top candidate.
1153+
1154+
If False, don't return logprobs. If True, return logprobs for top candidate.
11551155
If int, return logprobs for top ``logprobs`` candidates.
1156-
1156+
11571157
**NOTE**: As of 10.28.24 this is only supported for gemini-1.5-flash models.
1158-
1158+
11591159
.. versionadded: 2.0.6
11601160
"""
11611161
labels: Optional[Dict[str, str]] = None
@@ -2025,41 +2025,43 @@ class AnswerWithJustification(BaseModel):
20252025
parser: OutputParserLike
20262026

20272027
if method == "json_mode":
2028-
schema_is_typeddict = is_typeddict(schema)
2029-
if isinstance(schema, type) and not schema_is_typeddict:
2030-
# TODO: This gets the json schema of a pydantic model. It fails for
2031-
# nested models because the generated schema contains $refs that the
2032-
# gemini api doesn't support. We can implement a postprocessing function
2033-
# that takes care of this if necessary.
2028+
if isinstance(schema, type) and is_basemodel_subclass(schema):
20342029
if issubclass(schema, BaseModelV1):
20352030
schema_json = schema.schema()
20362031
else:
20372032
schema_json = schema.model_json_schema()
2038-
schema_json = replace_defs_in_schema(schema_json)
20392033
parser = PydanticOutputParser(pydantic_object=schema)
20402034
else:
2041-
if schema_is_typeddict:
2035+
if is_typeddict(schema):
20422036
schema_json = convert_to_json_schema(schema)
2037+
elif isinstance(schema, dict):
2038+
schema_json = schema
20432039
else:
2044-
schema_json = cast(dict, schema)
2040+
raise ValueError(f"Unsupported schema type {type(schema)}")
20452041
parser = JsonOutputParser()
2042+
2043+
# Resolve refs in schema because they are not supported
2044+
# by the Gemini API.
2045+
schema_json = replace_defs_in_schema(schema_json)
2046+
20462047
llm = self.bind(
20472048
response_mime_type="application/json",
20482049
response_schema=schema_json,
20492050
ls_structured_output_format={
20502051
"kwargs": {"method": method},
2051-
"schema": convert_to_json_schema(schema),
2052+
"schema": schema_json,
20522053
},
20532054
)
2054-
20552055
else:
20562056
tool_name = _get_tool_name(schema)
20572057
if isinstance(schema, type) and is_basemodel_subclass(schema):
20582058
parser = PydanticToolsParser(tools=[schema], first_tool_only=True)
2059-
else:
2059+
elif is_typeddict(schema) or isinstance(schema, dict):
20602060
parser = JsonOutputKeyToolsParser(
20612061
key_name=tool_name, first_tool_only=True
20622062
)
2063+
else:
2064+
raise ValueError(f"Unsupported schema type {type(schema)}")
20632065
tool_choice = tool_name if self._is_gemini_advanced else None
20642066

20652067
try:

libs/vertexai/langchain_google_vertexai/functions_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _format_json_schema_to_gapic(
112112
"""Format a JSON schema from a Pydantic V2 BaseModel to gapic."""
113113
converted_schema: Dict[str, Any] = {}
114114
for key, value in schema.items():
115-
if key == "definitions":
115+
if key == "$defs":
116116
continue
117117
elif key == "items":
118118
converted_schema["items"] = _format_json_schema_to_gapic(
@@ -157,7 +157,10 @@ def _format_json_schema_to_gapic(
157157
def _dict_to_gapic_schema(
158158
schema: Dict[str, Any], pydantic_version: str = "v1"
159159
) -> gapic.Schema:
160+
# Resolve refs in schema because $refs and $defs are not supported
161+
# by the Gemini API.
160162
dereferenced_schema = dereference_refs(schema)
163+
161164
if pydantic_version == "v1":
162165
formatted_schema = _format_json_schema_to_gapic_v1(dereferenced_schema)
163166
else:

0 commit comments

Comments
 (0)