1717 GoogleGenerativeAIError ,
1818 get_client_info ,
1919)
20- from langchain_google_genai ._genai_extension import build_generative_service
20+ from langchain_google_genai ._genai_extension import (
21+ build_generative_async_service ,
22+ build_generative_service ,
23+ )
2124
2225_MAX_TOKENS_PER_BATCH = 20000
2326_DEFAULT_BATCH_SIZE = 100
@@ -29,8 +32,8 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
2932 To use, you must have either:
3033
3134 1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
32- 2. Pass your API key using the google_api_key kwarg
33- to the GoogleGenerativeAIEmbeddings constructor.
35+ 2. Pass your API key using the google_api_key kwarg to the
36+ GoogleGenerativeAIEmbeddings constructor.
3437
3538 Example:
3639 .. code-block:: python
@@ -42,6 +45,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
4245 """
4346
4447 client : Any = None #: :meta private:
48+ async_client : Any = None #: :meta private:
4549 model : str = Field (
4650 ...,
4751 description = "The name of the embedding model to use. "
@@ -100,6 +104,13 @@ def validate_environment(self) -> Self:
100104 client_options = self .client_options ,
101105 transport = self .transport ,
102106 )
107+ self .async_client = build_generative_async_service (
108+ credentials = self .credentials ,
109+ api_key = google_api_key ,
110+ client_info = client_info ,
111+ client_options = self .client_options ,
112+ transport = self .transport ,
113+ )
103114 return self
104115
105116 @staticmethod
@@ -166,12 +177,12 @@ def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
166177 def _prepare_request (
167178 self ,
168179 text : str ,
180+ * ,
169181 task_type : Optional [str ] = None ,
170182 title : Optional [str ] = None ,
171183 output_dimensionality : Optional [int ] = None ,
172184 ) -> EmbedContentRequest :
173185 task_type = self .task_type or task_type or "RETRIEVAL_DOCUMENT"
174- # https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
175186 request = EmbedContentRequest (
176187 content = {"parts" : [{"text" : text }]},
177188 model = self .model ,
@@ -190,16 +201,17 @@ def embed_documents(
190201 titles : Optional [List [str ]] = None ,
191202 output_dimensionality : Optional [int ] = None ,
192203 ) -> List [List [float ]]:
193- """Embed a list of strings. Google Generative AI currently
194- sets a max batch size of 100 strings.
204+ """Embed a list of strings using the `batch endpoint <https://ai.google.dev/api/embeddings#method:-models.batchembedcontents>`__.
205+
206+ Google Generative AI currently sets a max batch size of 100 strings.
195207
196208 Args:
197209 texts: List[str] The list of strings to embed.
198210 batch_size: [int] The batch size of embeddings to send to the model
199- task_type: `task_type <https://ai.google.dev/api/rest/v1/TaskType >`__
211+ task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype >`__
200212 titles: An optional list of titles for texts provided.
201- Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
202- output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/rest/v1/models/batchEmbedContents #EmbedContentRequest>`__.
213+ Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
214+ output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings #EmbedContentRequest>`__.
203215 Returns:
204216 List of embeddings, one for each text.
205217 """
@@ -236,25 +248,26 @@ def embed_documents(
236248 def embed_query (
237249 self ,
238250 text : str ,
251+ * ,
239252 task_type : Optional [str ] = None ,
240253 title : Optional [str ] = None ,
241254 output_dimensionality : Optional [int ] = None ,
242255 ) -> List [float ]:
243- """Embed a text, using the `non-batch endpoint <https://ai.google.dev/api/rest/v1/models/embedContent#EmbedContentRequest >`__.
256+ """Embed a text, using the `non-batch endpoint <https://ai.google.dev/api/embeddings#method:-models.embedcontent >`__.
244257
245258 Args:
246259 text: The text to embed.
247- task_type: `task_type <https://ai.google.dev/api/rest/v1/TaskType >`__
260+ task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype >`__
248261 title: An optional title for the text.
249- Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
250- output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/rest/v1/models/batchEmbedContents #EmbedContentRequest>`__.
262+ Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
263+ output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings #EmbedContentRequest>`__.
251264
252265 Returns:
253266 Embedding for the text.
254267 """
255268 task_type_to_use = task_type if task_type else self .task_type
256269 if task_type_to_use is None :
257- task_type_to_use = "RETRIEVAL_QUERY" # Default to RETRIEVAL_QUERY
270+ task_type_to_use = "RETRIEVAL_QUERY"
258271 try :
259272 request : EmbedContentRequest = self ._prepare_request (
260273 text = text ,
@@ -266,3 +279,93 @@ def embed_query(
266279 except Exception as e :
267280 raise GoogleGenerativeAIError (f"Error embedding content: { e } " ) from e
268281 return list (result .embedding .values )
282+
async def aembed_documents(
    self,
    texts: List[str],
    *,
    batch_size: int = _DEFAULT_BATCH_SIZE,
    task_type: Optional[str] = None,
    titles: Optional[List[str]] = None,
    output_dimensionality: Optional[int] = None,
) -> List[List[float]]:
    """Asynchronously embed a list of strings using the `batch endpoint <https://ai.google.dev/api/embeddings#method:-models.batchembedcontents>`__.

    Google Generative AI currently sets a max batch size of 100 strings.

    Args:
        texts: The list of strings to embed.
        batch_size: The batch size of embeddings to send to the model.
        task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
        titles: An optional list of titles for texts provided, one per text.
            Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
        output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.

    Returns:
        List of embeddings, one for each text.

    Raises:
        ValueError: If ``titles`` is non-empty and its length differs from
            the length of ``texts``.
        GoogleGenerativeAIError: If a batch request to the service fails.
    """
    # Titles are paired with texts positionally via ``zip`` below, which stops
    # at the shorter sequence. A short ``titles`` list would therefore drop
    # texts silently and return fewer embeddings than inputs, so reject the
    # mismatch before any request is sent.
    if titles and len(titles) != len(texts):
        raise ValueError(
            f"Got {len(titles)} titles for {len(texts)} texts; "
            "provide exactly one title per text."
        )

    embeddings: List[List[float]] = []
    offset = 0  # index into ``titles`` of the first text of the current batch
    for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
        if titles:
            titles_batch = titles[offset : offset + len(batch)]
        else:
            titles_batch = [None] * len(batch)  # type: ignore[list-item]
        offset += len(batch)

        requests = [
            self._prepare_request(
                text=text,
                task_type=task_type,
                title=title,
                output_dimensionality=output_dimensionality,
            )
            for text, title in zip(batch, titles_batch)
        ]

        try:
            result = await self.async_client.batch_embed_contents(
                BatchEmbedContentsRequest(requests=requests, model=self.model)
            )
        except Exception as e:
            raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
        embeddings.extend(list(embedding.values) for embedding in result.embeddings)
    return embeddings
335+
async def aembed_query(
    self,
    text: str,
    *,
    task_type: Optional[str] = None,
    title: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> List[float]:
    """Asynchronously embed a text, using the `non-batch endpoint <https://ai.google.dev/api/embeddings#method:-models.embedcontent>`__.

    Args:
        text: The text to embed.
        task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
            When omitted, the instance-level ``task_type`` is used, and if
            that is unset too the request defaults to ``'RETRIEVAL_QUERY'``.
        title: An optional title for the text.
            Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
        output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.

    Returns:
        Embedding for the text.

    Raises:
        GoogleGenerativeAIError: If building the request or calling the
            service fails.
    """
    # Resolve the task type here and forward the *resolved* value. Passing the
    # raw ``task_type`` would let ``_prepare_request`` fall back to its own
    # document-oriented default, embedding queries as documents.
    # NOTE(review): ``_prepare_request`` gives ``self.task_type`` precedence
    # over the value passed in, so a per-call override only takes effect when
    # the instance-level task type is unset - confirm that is intended.
    effective_task_type = task_type or self.task_type or "RETRIEVAL_QUERY"
    try:
        request: EmbedContentRequest = self._prepare_request(
            text=text,
            task_type=effective_task_type,
            title=title,
            output_dimensionality=output_dimensionality,
        )
        result: EmbedContentResponse = await self.async_client.embed_content(
            request
        )
    except Exception as e:
        raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
    return list(result.embedding.values)
0 commit comments