@@ -772,7 +772,7 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer):
             raise ValueError(
                 f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
             )
@@ -1300,25 +1300,28 @@ def generate(
             )
             max_seq_len = self.max_cache_size

+        # Prefill.
         self.stats.on_sampling_begin()
         logits = self.forward(
-            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-            cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device),
+            cache_position=torch.arange(len(prompt_tokens[0]), dtype=torch.long, device=self.device),
             input_features=input_features,
         )
         self.stats.on_sampling_end()
-        next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
-        first_token_generated = False

-        generated_tokens = prompt_tokens + [next_token]
+        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
+        generated_tokens = [next_token]
+        print(self.tokenizer.decode([next_token]), end="")

-        while len(generated_tokens) < max_seq_len:
+        # Token-by-token generation.
+        first_token_generated = False
+        while len(generated_tokens) + len(prompt_tokens) < max_seq_len:
             self.stats.on_sampling_begin()
             logits = self.forward(
                 input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
                 cache_position=torch.tensor(
-                    [pos_base + len(generated_tokens) - 1],
+                    [pos_base + len(generated_tokens) + len(prompt_tokens) - 1],
                     dtype=torch.long,
                     device=self.device,
                 ),
@@ -1328,20 +1331,20 @@ def generate(
                 self.stats.on_first_token()
                 first_token_generated = True

-            next_token = torch.argmax(logits, dim=-1).item()
+            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
             generated_tokens.append(next_token)
+            print(self.tokenizer.decode([next_token]), end="")

-            if next_token in self.eos_token_ids:
+            if next_token == self.eos_token_id:
                 break

         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
-
         return generated_tokens if echo else generated_tokens[len(prompt_tokens):]

     def text_generation(
         self,
         processor: "ProcessorMixin",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_conversation: List[Dict],
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -1368,22 +1371,21 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_id, self.tokenizer):
             raise ValueError(
-                f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
+                f"The tokenizer's eos_token_id does not match with the model's eos_token_id={self.eos_token_id}."
             )

         # Reset stats for a new generation
         self.stats.reset()
         self.stats.on_inference_start()

         inputs = processor.apply_chat_template(input_conversation)
-        prompt_tokens = self.tokenizer.encode(inputs["input_ids"])
         self.stats.on_token_encode_end()
-        self.stats.set_num_prompt_tokens(len(prompt_tokens))
+        self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))

         generated_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+            prompt_tokens=inputs["input_ids"],
             input_features=inputs["input_features"],
             echo=echo,
             max_seq_len=max_seq_len,
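
The hunks above switch `generate` from a flat 1-D prompt to a batched (2-D) `input_ids`: one prefill forward pass over the whole prompt, then a token-by-token decode loop that tracks only newly generated tokens and selects the next token from `logits[:, -1, :]`. The snippet below is a minimal, self-contained sketch of that prefill-then-decode pattern; the dummy `TinyLM` model, vocabulary size, prompt values, and the use of `len(prompt_tokens[0])` for the prompt length are illustrative assumptions, not the exported ExecuTorch model, `pos_base`, stats tracking, or `input_features` handling used in the real code.

```python
import torch


class TinyLM(torch.nn.Module):
    """Stand-in model: embeds tokens and projects to vocab logits."""

    def __init__(self, vocab_size: int = 128, dim: int = 32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        # Returns logits of shape (batch, seq_len, vocab); cache_position is unused here
        # but mirrors the signature used in the diff.
        return self.head(self.embed(input_ids))


model = TinyLM()
eos_token_id = 2            # illustrative
max_seq_len = 16            # illustrative
prompt_tokens = [[5, 7, 11]]  # batched (2-D) prompt, matching the new calling convention

# Prefill: run the whole prompt once and take the logits at the last position.
input_ids = torch.tensor(prompt_tokens, dtype=torch.long)
logits = model(input_ids, cache_position=torch.arange(len(prompt_tokens[0]), dtype=torch.long))
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_tokens = [next_token]

# Decode: feed one token at a time; the cache position advances past the prompt.
while len(generated_tokens) + len(prompt_tokens[0]) < max_seq_len:
    logits = model(
        torch.tensor([next_token], dtype=torch.long).unsqueeze(0),
        cache_position=torch.tensor(
            [len(prompt_tokens[0]) + len(generated_tokens) - 1], dtype=torch.long
        ),
    )
    next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
    generated_tokens.append(next_token)
    if next_token == eos_token_id:
        break

print(generated_tokens)
```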