Skip to content

Commit 396ed6f

Browse files
authored
Add GCS image support in ChatAnthropicVertex model (#808)
1 parent e390510 commit 396ed6f

File tree

3 files changed

+49
-19
lines changed

3 files changed

+49
-19
lines changed

libs/vertexai/langchain_google_vertexai/_anthropic_utils.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from langchain_core.utils.function_calling import convert_to_openai_tool
3232
from pydantic import BaseModel
3333

34+
from langchain_google_vertexai._image_utils import image_bytes_to_b64_string
35+
from langchain_google_vertexai._utils import load_image_from_gcs
36+
3437
if TYPE_CHECKING:
3538
from anthropic.types import (
3639
RawMessageStreamEvent, # type: ignore[unused-ignore, import-not-found]
@@ -44,7 +47,7 @@
4447
}
4548

4649

47-
def _format_image(image_url: str) -> Dict:
50+
def _format_image(image_url: str, project: Optional[str]) -> Dict:
4851
"""Formats a message image to a dict for anthropic api."""
4952
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
5053
match = re.match(regex, image_url)
@@ -60,6 +63,16 @@ def _format_image(image_url: str) -> Dict:
6063
"type": "url",
6164
"url": image_url,
6265
}
66+
elif image_url.startswith("gs://"):
67+
# Gets image and encodes to base64.
68+
image = load_image_from_gcs(image_url, project)
69+
return {
70+
"type": "base64",
71+
"media_type": image._mime_type(),
72+
"data": image_bytes_to_b64_string(
73+
image.data(), "ascii", image._mime_type().split("/")[-1]
74+
),
75+
}
6376
else:
6477
raise ValueError(
6578
"Anthropic only supports base64-encoded images and urls currently."
@@ -83,7 +96,9 @@ def _format_text_content(text: str) -> Dict[str, Union[str, Dict[str, Any]]]:
8396
return content
8497

8598

86-
def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMessage]):
99+
def _format_message_anthropic(
100+
message: Union[HumanMessage, AIMessage, SystemMessage], project: Optional[str]
101+
):
87102
"""Format a message for Anthropic API.
88103
89104
Args:
@@ -152,7 +167,7 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMess
152167

153168
if block["type"] == "image_url":
154169
# convert format
155-
source = _format_image(block["image_url"]["url"])
170+
source = _format_image(block["image_url"]["url"], project)
156171
content.append({"type": "image", "source": source})
157172
continue
158173

@@ -183,6 +198,7 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMess
183198

184199
def _format_messages_anthropic(
185200
messages: List[BaseMessage],
201+
project: Optional[str],
186202
) -> Tuple[Optional[Dict[str, Any]], List[Dict]]:
187203
"""Formats messages for anthropic."""
188204
system_messages: Optional[Dict[str, Any]] = None
@@ -193,12 +209,12 @@ def _format_messages_anthropic(
193209
if message.type == "system":
194210
if i != 0:
195211
raise ValueError("System message must be at beginning of message list.")
196-
fm = _format_message_anthropic(message)
212+
fm = _format_message_anthropic(message, project)
197213
if fm:
198214
system_messages = fm
199215
continue
200216

201-
fm = _format_message_anthropic(message)
217+
fm = _format_message_anthropic(message, project)
202218
if not fm:
203219
continue
204220
formatted_messages.append(fm)

libs/vertexai/langchain_google_vertexai/model_garden.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ def _format_params(
248248
stop: Optional[List[str]] = None,
249249
**kwargs: Any,
250250
) -> Dict[str, Any]:
251-
system_message, formatted_messages = _format_messages_anthropic(messages)
251+
system_message, formatted_messages = _format_messages_anthropic(
252+
messages, self.project
253+
)
252254
params = self._default_params
253255
params.update(kwargs)
254256
if kwargs.get("model_name"):

libs/vertexai/tests/unit_tests/test_anthropic_utils.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_format_message_anthropic_with_cache_control_in_kwargs():
2727
message = HumanMessage(
2828
content="Hello", additional_kwargs={"cache_control": {"type": "semantic"}}
2929
)
30-
result = _format_message_anthropic(message)
30+
result = _format_message_anthropic(message, project="test-project")
3131
assert result == {
3232
"role": "user",
3333
"content": [
@@ -43,7 +43,7 @@ def test_format_message_anthropic_with_cache_control_in_block():
4343
{"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}}
4444
]
4545
)
46-
result = _format_message_anthropic(message)
46+
result = _format_message_anthropic(message, project="test-project")
4747
assert result == {
4848
"role": "user",
4949
"content": [
@@ -61,7 +61,7 @@ def test_format_message_anthropic_with_mixed_blocks():
6161
"Plain text",
6262
]
6363
)
64-
result = _format_message_anthropic(message)
64+
result = _format_message_anthropic(message, project="test-project")
6565
assert result == {
6666
"role": "user",
6767
"content": [
@@ -81,7 +81,9 @@ def test_format_messages_anthropic_with_system_cache_control():
8181
),
8282
HumanMessage(content="Hello"),
8383
]
84-
system_messages, formatted_messages = _format_messages_anthropic(messages)
84+
system_messages, formatted_messages = _format_messages_anthropic(
85+
messages, project="test-project"
86+
)
8587

8688
assert system_messages == [
8789
{
@@ -102,7 +104,7 @@ def test_format_message_anthropic_system():
102104
content="System message",
103105
additional_kwargs={"cache_control": {"type": "ephemeral"}},
104106
)
105-
result = _format_message_anthropic(message)
107+
result = _format_message_anthropic(message, project="test-project")
106108
assert result == [
107109
{
108110
"type": "text",
@@ -124,7 +126,7 @@ def test_format_message_anthropic_system_list():
124126
{"type": "text", "text": "System rule 2"},
125127
]
126128
)
127-
result = _format_message_anthropic(message)
129+
result = _format_message_anthropic(message, project="test-project")
128130
assert result == [
129131
{
130132
"type": "text",
@@ -156,7 +158,7 @@ def test_format_message_anthropic_with_chain_of_thoughts():
156158
},
157159
]
158160
)
159-
result = _format_message_anthropic(message)
161+
result = _format_message_anthropic(message, project="test-project")
160162
assert result == [
161163
{
162164
"type": "text",
@@ -185,7 +187,7 @@ def test_format_message_anthropic_with_image_content():
185187
},
186188
]
187189
)
188-
result = _format_message_anthropic(message)
190+
result = _format_message_anthropic(message, project="test-project")
189191
assert result == [
190192
{
191193
"type": "image",
@@ -211,7 +213,9 @@ def test_format_messages_anthropic_with_system_string():
211213
SystemMessage(content="System message"),
212214
HumanMessage(content="Hello"),
213215
]
214-
system_messages, formatted_messages = _format_messages_anthropic(messages)
216+
system_messages, formatted_messages = _format_messages_anthropic(
217+
messages, project="test-project"
218+
)
215219

216220
assert system_messages == [{"type": "text", "text": "System message"}]
217221

@@ -235,7 +239,9 @@ def test_format_messages_anthropic_with_system_list():
235239
),
236240
HumanMessage(content="Hello"),
237241
]
238-
system_messages, formatted_messages = _format_messages_anthropic(messages)
242+
system_messages, formatted_messages = _format_messages_anthropic(
243+
messages, project="test-project"
244+
)
239245

240246
assert system_messages == [
241247
{
@@ -266,7 +272,9 @@ def test_format_messages_anthropic_with_system_mixed_list():
266272
),
267273
HumanMessage(content="Hello"),
268274
]
269-
system_messages, formatted_messages = _format_messages_anthropic(messages)
275+
system_messages, formatted_messages = _format_messages_anthropic(
276+
messages, project="test-project"
277+
)
270278

271279
assert system_messages == [
272280
{"type": "text", "text": "Plain system rule"},
@@ -308,7 +316,9 @@ def test_format_messages_anthropic_with_mixed_messages():
308316
additional_kwargs={"cache_control": {"type": "semantic"}},
309317
),
310318
]
311-
system_messages, formatted_messages = _format_messages_anthropic(messages)
319+
system_messages, formatted_messages = _format_messages_anthropic(
320+
messages, project="test-project"
321+
)
312322

313323
assert system_messages == [
314324
{
@@ -806,7 +816,9 @@ def test_format_messages_anthropic(
806816
source_history, expected_sm, expected_history
807817
) -> None:
808818
"""Test the original format_messages_anthropic functionality."""
809-
sm, result_history = _format_messages_anthropic(source_history)
819+
sm, result_history = _format_messages_anthropic(
820+
source_history, project="test-project"
821+
)
810822

811823
for result, expected in zip(result_history, expected_history):
812824
assert result == expected

0 commit comments

Comments
 (0)