@@ -53,6 +53,19 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
     return "".join(messages)
 
 
+def _parse_gemma_chat_response(response: str) -> str:
+    """Removes chat history from the response."""
+    pattern = "<start_of_turn>model\n"
+    pos = response.rfind(pattern)
+    if pos == -1:
+        return response
+    text = response[(pos + len(pattern)) :]
+    pos = text.find("<start_of_turn>user\n")
+    if pos > 0:
+        return text[:pos]
+    return text
+
+
 class _GemmaBase(BaseModel):
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
@@ -98,6 +111,9 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha
98111 "top_k" ,
99112 "max_tokens" ,
100113 ]
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions
+    or multi-turn statements."""
 
     @property
     def _llm_type(self) -> str:
@@ -120,6 +136,8 @@ def _generate(
         request["prompt"] = gemma_messages_to_prompt(messages)
         output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
         text = output.predictions[0]
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -143,6 +161,8 @@ async def _agenerate(
             endpoint=self.endpoint_path, instances=[request]
         )
         text = output.predictions[0]
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -183,6 +203,11 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
 
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
     """Local gemma chat model loaded from Kaggle."""
@@ -195,7 +220,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         results = self.client.generate(prompts, **params)
         results = [results] if isinstance(results, str) else results
         if stop:
@@ -209,16 +234,22 @@ def _llm_type(self) -> str:
 
 
 class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions
+    or multi-turn statements."""
+
     def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         prompt = gemma_messages_to_prompt(messages)
         text = self.client.generate(prompt, **params)
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
@@ -268,9 +299,15 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
     def _run(self, prompt: str, **kwargs: Any) -> str:
         inputs = self.tokenizer(prompt, return_tensors="pt")
-        generate_ids = self.client.generate(inputs.input_ids, **kwargs)
+        params = self._get_params(**kwargs)
+        generate_ids = self.client.generate(inputs.input_ids, **params)
         return self.tokenizer.batch_decode(
             generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )[0]
@@ -287,8 +324,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
-        results = [self._run(prompt, **params) for prompt in prompts]
+        results = [self._run(prompt, **kwargs) for prompt in prompts]
         if stop:
             results = [enforce_stop_tokens(text, stop) for text in results]
         return LLMResult(generations=[[Generation(text=text)] for text in results])
@@ -300,7 +336,9 @@ def _llm_type(self) -> str:
 
 
 class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
-    """Local gemma chat model loaded from HuggingFace."""
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions
+    or multi-turn statements."""
 
     def _generate(
         self,
@@ -309,9 +347,10 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
         prompt = gemma_messages_to_prompt(messages)
-        text = self._run(prompt, **params)
+        text = self._run(prompt, **kwargs)
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
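Taken together, the changes let callers enable the cleanup either on the model instance or per call (note the `or`: a per-call `parse_response=False` cannot switch it off once the attribute is `True`). A minimal usage sketch, assuming the class is exported from `langchain_google_vertexai`; the project, location, and endpoint values are placeholders:

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import GemmaChatVertexAIModelGarden

# All identifiers below are placeholders, not values from this commit.
llm = GemmaChatVertexAIModelGarden(
    project="my-project",
    location="us-central1",
    endpoint_id="1234567890",
    parse_response=True,  # strip the echoed conversation from every reply
)

message = HumanMessage(content="How much is 2 + 2?")
answer = llm.invoke([message])  # post-processed reply
# Alternatively, leave parse_response=False on the instance and opt in per call:
# answer = llm.invoke([message], parse_response=True)
```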