Skip to content

Commit c1ff72f

Browse files
authored
feat: add async embedding methods and update tests for output dimensionality for Gemini embedding compat (#1031)
2 parents a7033ca + cbcffeb commit c1ff72f

File tree

3 files changed

+174
-40
lines changed

3 files changed

+174
-40
lines changed

libs/genai/langchain_google_genai/embeddings.py

Lines changed: 117 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
GoogleGenerativeAIError,
1818
get_client_info,
1919
)
20-
from langchain_google_genai._genai_extension import build_generative_service
20+
from langchain_google_genai._genai_extension import (
21+
build_generative_async_service,
22+
build_generative_service,
23+
)
2124

2225
_MAX_TOKENS_PER_BATCH = 20000
2326
_DEFAULT_BATCH_SIZE = 100
@@ -29,8 +32,8 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
2932
To use, you must have either:
3033
3134
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
32-
2. Pass your API key using the google_api_key kwarg
33-
to the GoogleGenerativeAIEmbeddings constructor.
35+
2. Pass your API key using the google_api_key kwarg to the
36+
GoogleGenerativeAIEmbeddings constructor.
3437
3538
Example:
3639
.. code-block:: python
@@ -42,6 +45,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
4245
"""
4346

4447
client: Any = None #: :meta private:
48+
async_client: Any = None #: :meta private:
4549
model: str = Field(
4650
...,
4751
description="The name of the embedding model to use. "
@@ -100,6 +104,13 @@ def validate_environment(self) -> Self:
100104
client_options=self.client_options,
101105
transport=self.transport,
102106
)
107+
self.async_client = build_generative_async_service(
108+
credentials=self.credentials,
109+
api_key=google_api_key,
110+
client_info=client_info,
111+
client_options=self.client_options,
112+
transport=self.transport,
113+
)
103114
return self
104115

105116
@staticmethod
@@ -166,12 +177,12 @@ def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
166177
def _prepare_request(
167178
self,
168179
text: str,
180+
*,
169181
task_type: Optional[str] = None,
170182
title: Optional[str] = None,
171183
output_dimensionality: Optional[int] = None,
172184
) -> EmbedContentRequest:
173185
task_type = self.task_type or task_type or "RETRIEVAL_DOCUMENT"
174-
# https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
175186
request = EmbedContentRequest(
176187
content={"parts": [{"text": text}]},
177188
model=self.model,
@@ -190,16 +201,17 @@ def embed_documents(
190201
titles: Optional[List[str]] = None,
191202
output_dimensionality: Optional[int] = None,
192203
) -> List[List[float]]:
193-
"""Embed a list of strings. Google Generative AI currently
194-
sets a max batch size of 100 strings.
204+
"""Embed a list of strings using the `batch endpoint <https://ai.google.dev/api/embeddings#method:-models.batchembedcontents>`__.
205+
206+
Google Generative AI currently sets a max batch size of 100 strings.
195207
196208
Args:
197209
texts: List[str] The list of strings to embed.
198210
batch_size: [int] The batch size of embeddings to send to the model
199-
task_type: `task_type <https://ai.google.dev/api/rest/v1/TaskType>`__
211+
task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
200212
titles: An optional list of titles for texts provided.
201-
Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
202-
output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest>`__.
213+
Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
214+
output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.
203215
Returns:
204216
List of embeddings, one for each text.
205217
"""
@@ -236,25 +248,26 @@ def embed_documents(
236248
def embed_query(
237249
self,
238250
text: str,
251+
*,
239252
task_type: Optional[str] = None,
240253
title: Optional[str] = None,
241254
output_dimensionality: Optional[int] = None,
242255
) -> List[float]:
243-
"""Embed a text, using the `non-batch endpoint <https://ai.google.dev/api/rest/v1/models/embedContent#EmbedContentRequest>`__.
256+
"""Embed a text, using the `non-batch endpoint <https://ai.google.dev/api/embeddings#method:-models.embedcontent>`__.
244257
245258
Args:
246259
text: The text to embed.
247-
task_type: `task_type <https://ai.google.dev/api/rest/v1/TaskType>`__
260+
task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
248261
title: An optional title for the text.
249-
Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
250-
output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest>`__.
262+
Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
263+
output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.
251264
252265
Returns:
253266
Embedding for the text.
254267
"""
255268
task_type_to_use = task_type if task_type else self.task_type
256269
if task_type_to_use is None:
257-
task_type_to_use = "RETRIEVAL_QUERY" # Default to RETRIEVAL_QUERY
270+
task_type_to_use = "RETRIEVAL_QUERY"
258271
try:
259272
request: EmbedContentRequest = self._prepare_request(
260273
text=text,
@@ -266,3 +279,93 @@ def embed_query(
266279
except Exception as e:
267280
raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
268281
return list(result.embedding.values)
282+
283+
async def aembed_documents(
    self,
    texts: List[str],
    *,
    batch_size: int = _DEFAULT_BATCH_SIZE,
    task_type: Optional[str] = None,
    titles: Optional[List[str]] = None,
    output_dimensionality: Optional[int] = None,
) -> List[List[float]]:
    """Asynchronously embed a list of strings using the `batch endpoint <https://ai.google.dev/api/embeddings#method:-models.batchembedcontents>`__.

    Google Generative AI currently sets a max batch size of 100 strings.

    The texts are split into batches by ``_prepare_batches``; one
    ``BatchEmbedContentsRequest`` is awaited per batch, and the resulting
    vectors are concatenated in batch order.

    Args:
        texts: The list of strings to embed.
        batch_size: The batch size of embeddings to send to the model.
        task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
        titles: An optional list of titles for texts provided.
            Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
        output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.

    Returns:
        List of embeddings, one for each text.

    Raises:
        GoogleGenerativeAIError: If building or sending a batch request fails.
    """
    vectors: List[List[float]] = []
    # Position in ``titles`` of the first text of the current batch.
    offset = 0
    for chunk in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
        chunk_len = len(chunk)
        if titles:
            chunk_titles = titles[offset : offset + chunk_len]
        else:
            # No titles supplied: pair every text with ``None``.
            chunk_titles = [None] * chunk_len  # type: ignore[list-item]
        # Only read when ``titles`` is set, so advancing unconditionally is safe.
        offset += chunk_len

        chunk_requests = []
        for text, title in zip(chunk, chunk_titles):
            chunk_requests.append(
                self._prepare_request(
                    text=text,
                    task_type=task_type,
                    title=title,
                    output_dimensionality=output_dimensionality,
                )
            )

        try:
            batch_request = BatchEmbedContentsRequest(
                requests=chunk_requests, model=self.model
            )
            response = await self.async_client.batch_embed_contents(batch_request)
        except Exception as e:
            raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
        vectors.extend(list(item.values) for item in response.embeddings)
    return vectors
335+
336+
async def aembed_query(
    self,
    text: str,
    *,
    task_type: Optional[str] = None,
    title: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> List[float]:
    """Asynchronously embed a text, using the `non-batch endpoint <https://ai.google.dev/api/embeddings#method:-models.embedcontent>`__.

    Args:
        text: The text to embed.
        task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
            When omitted, the instance-level ``task_type`` is used, and if
            that is unset as well, ``"RETRIEVAL_QUERY"``.
        title: An optional title for the text.
            Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
        output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.

    Returns:
        Embedding for the text.

    Raises:
        GoogleGenerativeAIError: If building or sending the request fails.
    """
    task_type_to_use = task_type if task_type else self.task_type
    if task_type_to_use is None:
        task_type_to_use = "RETRIEVAL_QUERY"
    try:
        # Forward the *resolved* task type. Passing the raw ``task_type``
        # argument would let ``_prepare_request`` fall back to its own
        # document-oriented default instead of the query default above.
        request: EmbedContentRequest = self._prepare_request(
            text=text,
            task_type=task_type_to_use,
            title=title,
            output_dimensionality=output_dimensionality,
        )
        result: EmbedContentResponse = await self.async_client.embed_content(
            request
        )
    except Exception as e:
        raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
    return list(result.embedding.values)

libs/genai/tests/integration_tests/test_chat_models.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,20 @@ def test_chat_google_genai_invoke_with_image() -> None:
130130
"""Test invoke tokens with image from ChatGoogleGenerativeAI."""
131131
llm = ChatGoogleGenerativeAI(model=_IMAGE_OUTPUT_MODEL)
132132

133-
result = llm.invoke(
134-
"Generate an image of a cat. Then, say meow!",
135-
config=dict(tags=["meow"]),
136-
generation_config=dict(
137-
top_k=2, top_p=1, temperature=0.7, response_modalities=["TEXT", "IMAGE"]
138-
),
139-
)
133+
for _ in range(3):
134+
result = llm.invoke(
135+
"Generate an image of a cat. Then, say meow!",
136+
config=dict(tags=["meow"]),
137+
generation_config=dict(
138+
top_k=2, top_p=1, temperature=0.7, response_modalities=["TEXT", "IMAGE"]
139+
),
140+
)
141+
if (
142+
isinstance(result.content, list)
143+
and len(result.content) > 0
144+
and isinstance(result.content[0], dict)
145+
):
146+
break
140147
assert isinstance(result, AIMessage)
141148
assert isinstance(result.content, list)
142149
assert isinstance(result.content[0], dict)
@@ -155,11 +162,18 @@ def test_chat_google_genai_invoke_with_modalities() -> None:
155162
response_modalities=[Modality.TEXT, Modality.IMAGE], # type: ignore[list-item]
156163
)
157164

158-
result = llm.invoke(
159-
"Generate an image of a cat. Then, say meow!",
160-
config=dict(tags=["meow"]),
161-
generation_config=dict(top_k=2, top_p=1, temperature=0.7),
162-
)
165+
for _ in range(3):
166+
result = llm.invoke(
167+
"Generate an image of a cat. Then, say meow!",
168+
config=dict(tags=["meow"]),
169+
generation_config=dict(top_k=2, top_p=1, temperature=0.7),
170+
)
171+
if (
172+
isinstance(result.content, list)
173+
and len(result.content) > 0
174+
and isinstance(result.content[0], dict)
175+
):
176+
break
163177
assert isinstance(result, AIMessage)
164178
assert isinstance(result.content, list)
165179
assert isinstance(result.content[0], dict)

libs/genai/tests/integration_tests/test_embeddings.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from langchain_google_genai._common import GoogleGenerativeAIError
66
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
77

8-
_MODEL = "models/embedding-001"
8+
_MODEL = "models/gemini-embedding-001"
9+
_OUTPUT_DIMENSIONALITY = 768
910

1011

1112
@pytest.mark.parametrize(
@@ -19,7 +20,7 @@
1920
def test_embed_query_different_lengths(query: str) -> None:
2021
"""Test embedding queries of different lengths."""
2122
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
22-
result = model.embed_query(query)
23+
result = model.embed_query(query, output_dimensionality=_OUTPUT_DIMENSIONALITY)
2324
assert len(result) == 768
2425
assert isinstance(result, list)
2526

@@ -35,7 +36,9 @@ def test_embed_query_different_lengths(query: str) -> None:
3536
async def test_aembed_query_different_lengths(query: str) -> None:
3637
"""Test embedding queries of different lengths."""
3738
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
38-
result = await model.aembed_query(query)
39+
result = await model.aembed_query(
40+
query, output_dimensionality=_OUTPUT_DIMENSIONALITY
41+
)
3942
assert len(result) == 768
4043
assert isinstance(result, list)
4144

@@ -45,7 +48,9 @@ def test_embed_documents() -> None:
4548
model = GoogleGenerativeAIEmbeddings(
4649
model=_MODEL,
4750
)
48-
result = model.embed_documents(["Hello world", "Good day, world"])
51+
result = model.embed_documents(
52+
["Hello world", "Good day, world"], output_dimensionality=_OUTPUT_DIMENSIONALITY
53+
)
4954
assert len(result) == 2
5055
assert len(result[0]) == 768
5156
assert len(result[1]) == 768
@@ -58,7 +63,9 @@ async def test_aembed_documents() -> None:
5863
model = GoogleGenerativeAIEmbeddings(
5964
model=_MODEL,
6065
)
61-
result = await model.aembed_documents(["Hello world", "Good day, world"])
66+
result = await model.aembed_documents(
67+
["Hello world", "Good day, world"], output_dimensionality=_OUTPUT_DIMENSIONALITY
68+
)
6269
assert len(result) == 2
6370
assert len(result[0]) == 768
6471
assert len(result[1]) == 768
@@ -69,23 +76,25 @@ async def test_aembed_documents() -> None:
6976
def test_invalid_model_error_handling() -> None:
7077
"""Test error handling with an invalid model name."""
7178
with pytest.raises(GoogleGenerativeAIError):
72-
GoogleGenerativeAIEmbeddings(model="invalid_model").embed_query("Hello world")
79+
GoogleGenerativeAIEmbeddings(model="invalid_model").embed_query(
80+
"Hello world", output_dimensionality=_OUTPUT_DIMENSIONALITY
81+
)
7382

7483

7584
def test_invalid_api_key_error_handling() -> None:
7685
"""Test error handling with an invalid API key."""
7786
with pytest.raises(GoogleGenerativeAIError):
7887
GoogleGenerativeAIEmbeddings(
7988
model=_MODEL, google_api_key=SecretStr("invalid_key")
80-
).embed_query("Hello world")
89+
).embed_query("Hello world", output_dimensionality=_OUTPUT_DIMENSIONALITY)
8190

8291

8392
def test_embed_documents_consistency() -> None:
8493
"""Test embedding consistency for the same document."""
8594
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
8695
doc = "Consistent document for testing"
87-
result1 = model.embed_documents([doc])
88-
result2 = model.embed_documents([doc])
96+
result1 = model.embed_documents([doc], output_dimensionality=_OUTPUT_DIMENSIONALITY)
97+
result2 = model.embed_documents([doc], output_dimensionality=_OUTPUT_DIMENSIONALITY)
8998
assert result1 == result2
9099

91100

@@ -94,8 +103,12 @@ def test_embed_documents_quality() -> None:
94103
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
95104
similar_docs = ["Document A", "Similar Document A"]
96105
dissimilar_docs = ["Document A", "Completely Different Zebra"]
97-
similar_embeddings = model.embed_documents(similar_docs)
98-
dissimilar_embeddings = model.embed_documents(dissimilar_docs)
106+
similar_embeddings = model.embed_documents(
107+
similar_docs, output_dimensionality=_OUTPUT_DIMENSIONALITY
108+
)
109+
dissimilar_embeddings = model.embed_documents(
110+
dissimilar_docs, output_dimensionality=_OUTPUT_DIMENSIONALITY
111+
)
99112
similar_distance = np.linalg.norm(
100113
np.array(similar_embeddings[0]) - np.array(similar_embeddings[1])
101114
)
@@ -109,16 +122,20 @@ def test_embed_query_task_type() -> None:
109122
"""Test for task_type"""
110123

111124
embeddings = GoogleGenerativeAIEmbeddings(model=_MODEL, task_type="clustering")
112-
emb = embeddings.embed_query("How does alphafold work?", output_dimensionality=768)
125+
emb = embeddings.embed_query(
126+
"How does alphafold work?", output_dimensionality=_OUTPUT_DIMENSIONALITY
127+
)
113128

114129
embeddings2 = GoogleGenerativeAIEmbeddings(model=_MODEL)
115130
emb2 = embeddings2.embed_query(
116-
"How does alphafold work?", task_type="clustering", output_dimensionality=768
131+
"How does alphafold work?",
132+
task_type="clustering",
133+
output_dimensionality=_OUTPUT_DIMENSIONALITY,
117134
)
118135

119136
embeddings3 = GoogleGenerativeAIEmbeddings(model=_MODEL)
120137
emb3 = embeddings3.embed_query(
121-
"How does alphafold work?", output_dimensionality=768
138+
"How does alphafold work?", output_dimensionality=_OUTPUT_DIMENSIONALITY
122139
)
123140

124141
assert emb == emb2

0 commit comments

Comments
 (0)