Skip to content

Commit 960ac13

Browse files
authored
fixed gemma_hf (#33)
* fix gemma_hf
1 parent 6421180 commit 960ac13

File tree

4 files changed

+21
-3
lines changed

4 files changed

+21
-3
lines changed

libs/vertexai/langchain_google_vertexai/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from langchain_google_vertexai.chat_models import ChatVertexAI
44
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
55
from langchain_google_vertexai.gemma import (
6+
GemmaChatLocalHF,
67
GemmaChatLocalKaggle,
78
GemmaChatVertexAIModelGarden,
89
GemmaLocalHF,
@@ -12,6 +13,13 @@
1213
from langchain_google_vertexai.llms import VertexAI
1314
from langchain_google_vertexai.model_garden import VertexAIModelGarden
1415
from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore
16+
from langchain_google_vertexai.vision_models import (
17+
VertexAIImageCaptioning,
18+
VertexAIImageCaptioningChat,
19+
VertexAIImageEditorChat,
20+
VertexAIImageGeneratorChat,
21+
VertexAIVisualQnAChat,
22+
)
1523

1624
__all__ = [
1725
"ChatVertexAI",
@@ -29,4 +37,9 @@
2937
"PydanticFunctionsOutputParser",
3038
"create_structured_runnable",
3139
"VectorSearchVectorStore",
40+
"VertexAIImageCaptioning",
41+
"VertexAIImageCaptioningChat",
42+
"VertexAIImageEditorChat",
43+
"VertexAIImageGeneratorChat",
44+
"VertexAIVisualQnAChat",
3245
]

libs/vertexai/langchain_google_vertexai/gemma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _generate(
197197
"""Run the LLM on the given prompt and input."""
198198
params = {"max_length": self.max_tokens} if self.max_tokens else {}
199199
results = self.client.generate(prompts, **params)
200-
results = results if isinstance(results, str) else [results]
200+
results = [results] if isinstance(results, str) else results
201201
if stop:
202202
results = [enforce_stop_tokens(text, stop) for text in results]
203203
return LLMResult(generations=[[Generation(text=result)] for result in results])
@@ -268,7 +268,7 @@ def _default_params(self) -> Dict[str, Any]:
268268
params = {"max_length": self.max_tokens}
269269
return {k: v for k, v in params.items() if v is not None}
270270

271-
def _run(self, prompt: str, kwargs: Any) -> str:
271+
def _run(self, prompt: str, **kwargs: Any) -> str:
272272
inputs = self.tokenizer(prompt, return_tensors="pt")
273273
generate_ids = self.client.generate(inputs.input_ids, **kwargs)
274274
return self.tokenizer.batch_decode(

libs/vertexai/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "langchain-google-vertexai"
3-
version = "0.0.6"
3+
version = "0.0.7"
44
description = "An integration package connecting GoogleVertexAI and LangChain"
55
authors = []
66
readme = "README.md"

libs/vertexai/tests/unit_tests/test_imports.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
"PydanticFunctionsOutputParser",
1717
"create_structured_runnable",
1818
"VectorSearchVectorStore",
19+
"VertexAIImageCaptioning",
20+
"VertexAIImageCaptioningChat",
21+
"VertexAIImageEditorChat",
22+
"VertexAIImageGeneratorChat",
23+
"VertexAIVisualQnAChat",
1924
]
2025

2126

0 commit comments

Comments (0)