@@ -245,6 +245,13 @@ def _on_end_session(self, reqs: Request, **kwargs):
245245
246246 def _on_add_message (self , reqs : Request , ** kwargs ):
247247 """on add message callback."""
248+
249+ def __update_bad_words (msg ):
250+ """update bad words."""
251+ sampling_param = msg .sampling_param
252+ if sampling_param .ignore_eos :
253+ sampling_param .bad_words .append (self .model_config .eos_token_id )
254+
248255 for req in reqs :
249256 session_id = req .data ['session_id' ]
250257 if session_id not in self .scheduler .sessions :
@@ -265,13 +272,15 @@ def _on_add_message(self, reqs: Request, **kwargs):
265272 sampling_param = req .data ['sampling_param' ],
266273 adapter_name = req .data ['adapter_name' ])
267274 msg = next (iter (sess .sequences .values ()))
275+ __update_bad_words (msg )
268276 self .scheduler .add_sequence (msg )
269277 else :
270278 msg = next (iter (sess .sequences .values ()))
271279 msg .update_token_ids (req .data ['token_ids' ])
272280 msg .remain_output_len = req .data ['max_request_output_len' ]
273281 msg .sampling_param = req .data ['sampling_param' ]
274282 msg .status = MessageStatus .WAITING
283+ __update_bad_words (msg )
275284
276285 msg .sender_id = req .sender_id
277286 msg .req_id = req .req_id
@@ -408,9 +417,8 @@ def _stopping_criteria(self, msg: SchedulerSequence, next_token_id: int):
408417 """
409418
410419 # check eof
411- def _check_eof (sampling_param , next_token_id , eos_token_id ):
412- return (not sampling_param .ignore_eos
413- ) and next_token_id == eos_token_id
420+ def _check_eof (next_token_id , eos_token_id ):
421+ return next_token_id == eos_token_id
414422
415423 def _check_stop_word (sampling_param , next_token_id ):
416424 return (sampling_param .stop_words is not None
@@ -426,8 +434,7 @@ def _check_session_len(msg, max_session_len):
426434 return session_len >= max_session_len
427435
428436 sampling_param = msg .sampling_param
429- if _check_eof (sampling_param , next_token_id ,
430- self .model_config .eos_token_id ):
437+ if _check_eof (next_token_id , self .model_config .eos_token_id ):
431438 return True
432439 if _check_stop_word (sampling_param , next_token_id ):
433440 return True
@@ -495,8 +502,7 @@ def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence):
495502 """check if output is necessary."""
496503 if isinstance (token , torch .Tensor ):
497504 token = token .item ()
498- ignore_eos = msg .sampling_param .ignore_eos
499- if not ignore_eos and token == self .model_config .eos_token_id :
505+ if token == self .model_config .eos_token_id :
500506 return False
501507
502508 stop_words = msg .sampling_param .stop_words
0 commit comments