
Commit 8be34db

added post-processing for local gemma (#40)
* added post-processing for local gemma
* fixes after review
1 parent 62ed23f commit 8be34db

File tree

  • libs/vertexai/langchain_google_vertexai

1 file changed: +47 -8 lines changed

libs/vertexai/langchain_google_vertexai/gemma.py

Lines changed: 47 additions & 8 deletions
@@ -53,6 +53,19 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
     return "".join(messages)
 
 
+def _parse_gemma_chat_response(response: str) -> str:
+    """Removes chat history from the response."""
+    pattern = "<start_of_turn>model\n"
+    pos = response.rfind(pattern)
+    if pos == -1:
+        return response
+    text = response[(pos + len(pattern)) :]
+    pos = text.find("<start_of_turn>user\n")
+    if pos > 0:
+        return text[:pos]
+    return text
+
+
 class _GemmaBase(BaseModel):
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
@@ -98,6 +111,9 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha
         "top_k",
         "max_tokens",
     ]
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions """
+    """or multi-turn statements."""
 
     @property
     def _llm_type(self) -> str:
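
The new flag can be set on the model or passed per call. A hedged usage sketch, assuming the class is imported from this module and that the endpoint and project values are placeholders (not from the commit):

    from langchain_core.messages import HumanMessage
    from langchain_google_vertexai.gemma import GemmaChatVertexAIModelGarden

    # Hypothetical endpoint and project identifiers for illustration only.
    llm = GemmaChatVertexAIModelGarden(
        endpoint_id="your-endpoint-id",
        project="your-project",
        location="us-central1",
        parse_response=True,  # strip echoed chat turns from every response
    )
    message = HumanMessage(content="Hi, how are you?")
    answer = llm.invoke([message])

    # The same post-processing can be requested for a single call instead:
    answer = llm.invoke([message], parse_response=True)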
@@ -120,6 +136,8 @@ def _generate(
         request["prompt"] = gemma_messages_to_prompt(messages)
         output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
         text = output.predictions[0]
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -143,6 +161,8 @@ async def _agenerate(
             endpoint=self.endpoint_path, instances=[request]
         )
         text = output.predictions[0]
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -183,6 +203,11 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
 
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
     """Local gemma chat model loaded from Kaggle."""
@@ -195,7 +220,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         results = self.client.generate(prompts, **params)
         results = [results] if isinstance(results, str) else results
         if stop:
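
Since `_generate` now builds its call parameters via `_get_params(**kwargs)`, `max_tokens` can be overridden per invocation on the plain LLM as well. A hedged sketch, assuming the default Kaggle model handle and a local keras-nlp setup:

    from langchain_google_vertexai.gemma import GemmaLocalKaggle

    # Hypothetical local setup; weights are fetched via Kaggle/keras-nlp.
    llm = GemmaLocalKaggle(model_name="gemma_2b_en", keras_backend="jax")

    # Per-call override: _get_params maps max_tokens -> max_length here.
    print(llm.invoke("What is the meaning of life?", max_tokens=30))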
@@ -209,16 +234,22 @@ def _llm_type(self) -> str:
 
 
 class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions """
+    """or multi-turn statements."""
+
     def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         prompt = gemma_messages_to_prompt(messages)
         text = self.client.generate(prompt, **params)
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
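
For the chat variant, the same flag triggers `_parse_gemma_chat_response` on the generated text. A hedged usage sketch (model name and backend are assumed defaults, not taken from this commit):

    from langchain_core.messages import HumanMessage
    from langchain_google_vertexai.gemma import GemmaChatLocalKaggle

    llm = GemmaChatLocalKaggle(
        model_name="gemma_2b_en",
        keras_backend="jax",
        parse_response=True,  # drop the echoed prompt turns from the reply
        max_tokens=30,        # forwarded to the client as max_length
    )
    print(llm.invoke([HumanMessage(content="Hi! Who are you?")]).content)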
@@ -268,9 +299,15 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
     def _run(self, prompt: str, **kwargs: Any) -> str:
         inputs = self.tokenizer(prompt, return_tensors="pt")
-        generate_ids = self.client.generate(inputs.input_ids, **kwargs)
+        params = self._get_params(**kwargs)
+        generate_ids = self.client.generate(inputs.input_ids, **params)
         return self.tokenizer.batch_decode(
             generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )[0]
@@ -287,8 +324,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
-        results = [self._run(prompt, **params) for prompt in prompts]
+        results = [self._run(prompt, **kwargs) for prompt in prompts]
         if stop:
             results = [enforce_stop_tokens(text, stop) for text in results]
         return LLMResult(generations=[[Generation(text=text)] for text in results])
@@ -300,7 +336,9 @@ def _llm_type(self) -> str:
 
 
 class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
-    """Local gemma chat model loaded from HuggingFace."""
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions """
+    """or multi-turn statements."""
 
     def _generate(
         self,
@@ -309,9 +347,10 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
         prompt = gemma_messages_to_prompt(messages)
-        text = self._run(prompt, **params)
+        text = self._run(prompt, **kwargs)
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
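
An equivalent hedged sketch for the HuggingFace-backed chat model; the model id and access-token field are assumptions for illustration, not part of this commit:

    from langchain_core.messages import HumanMessage
    from langchain_google_vertexai.gemma import GemmaChatLocalHF

    # Hypothetical values; pulling google/gemma-2b needs an HF access token.
    llm = GemmaChatLocalHF(
        model_name="google/gemma-2b",
        hf_access_token="hf_...",
        parse_response=True,  # post-process away repeated/multi-turn text
        max_tokens=50,        # mapped to max_length for generate()
    )
    print(llm.invoke([HumanMessage(content="What is the capital of France?")]).content)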
