Skip to content

Commit c332efa

Browse files
authored
dependency version check & fix ignore_eos logic (#1099)
* version check
* lock peft version
1 parent b88df3f commit c332efa

File tree

4 files changed

+24
-12
lines changed

4 files changed

+24
-12
lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,13 @@ def _on_end_session(self, reqs: Request, **kwargs):
245245

246246
def _on_add_message(self, reqs: Request, **kwargs):
247247
"""on add message callback."""
248+
249+
def __update_bad_words(msg):
250+
"""update bad words."""
251+
sampling_param = msg.sampling_param
252+
if sampling_param.ignore_eos:
253+
sampling_param.bad_words.append(self.model_config.eos_token_id)
254+
248255
for req in reqs:
249256
session_id = req.data['session_id']
250257
if session_id not in self.scheduler.sessions:
@@ -265,13 +272,15 @@ def _on_add_message(self, reqs: Request, **kwargs):
265272
sampling_param=req.data['sampling_param'],
266273
adapter_name=req.data['adapter_name'])
267274
msg = next(iter(sess.sequences.values()))
275+
__update_bad_words(msg)
268276
self.scheduler.add_sequence(msg)
269277
else:
270278
msg = next(iter(sess.sequences.values()))
271279
msg.update_token_ids(req.data['token_ids'])
272280
msg.remain_output_len = req.data['max_request_output_len']
273281
msg.sampling_param = req.data['sampling_param']
274282
msg.status = MessageStatus.WAITING
283+
__update_bad_words(msg)
275284

276285
msg.sender_id = req.sender_id
277286
msg.req_id = req.req_id
@@ -408,9 +417,8 @@ def _stopping_criteria(self, msg: SchedulerSequence, next_token_id: int):
408417
"""
409418

410419
# check eof
411-
def _check_eof(sampling_param, next_token_id, eos_token_id):
412-
return (not sampling_param.ignore_eos
413-
) and next_token_id == eos_token_id
420+
def _check_eof(next_token_id, eos_token_id):
421+
return next_token_id == eos_token_id
414422

415423
def _check_stop_word(sampling_param, next_token_id):
416424
return (sampling_param.stop_words is not None
@@ -426,8 +434,7 @@ def _check_session_len(msg, max_session_len):
426434
return session_len >= max_session_len
427435

428436
sampling_param = msg.sampling_param
429-
if _check_eof(sampling_param, next_token_id,
430-
self.model_config.eos_token_id):
437+
if _check_eof(next_token_id, self.model_config.eos_token_id):
431438
return True
432439
if _check_stop_word(sampling_param, next_token_id):
433440
return True
@@ -495,8 +502,7 @@ def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence):
495502
"""check if output is necessary."""
496503
if isinstance(token, torch.Tensor):
497504
token = token.item()
498-
ignore_eos = msg.sampling_param.ignore_eos
499-
if not ignore_eos and token == self.model_config.eos_token_id:
505+
if token == self.model_config.eos_token_id:
500506
return False
501507

502508
stop_words = msg.sampling_param.stop_words

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __call__(self, input_ids: torch.LongTensor,
6161
# bad words
6262
bad_words = self.sampling_param.bad_words
6363
if bad_words:
64+
bad_words = list(set(bad_words))
6465
bad_words_bias = new_scores.new_zeros(new_scores.size(1))
6566
bad_words_bias[bad_words] = filter_value
6667
new_scores += bad_words_bias[None]

lmdeploy/pytorch/messages.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,18 @@ def __hash__(self):
3535
def from_gen_config(self, gen_config: EngineGenerationConfig):
3636
"""from gen config."""
3737

38+
stop_words = gen_config.stop_words or []
39+
bad_words = gen_config.bad_words or []
40+
if gen_config.ignore_eos:
41+
bad_words += stop_words
3842
return SamplingParam(top_p=gen_config.top_p,
3943
top_k=gen_config.top_k,
4044
temperature=gen_config.temperature,
4145
repetition_penalty=gen_config.repetition_penalty,
4246
ignore_eos=gen_config.ignore_eos,
4347
random_seed=gen_config.random_seed,
44-
stop_words=gen_config.stop_words,
45-
bad_words=gen_config.bad_words)
48+
stop_words=stop_words,
49+
bad_words=bad_words)
4650

4751

4852
class MessageStatus(enum.Enum):

requirements/runtime.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ fire
33
fuzzywuzzy
44
mmengine-lite
55
numpy
6-
peft
6+
peft==0.7.1
77
pydantic>2.0.0
88
pynvml
99
safetensors
1010
sentencepiece
1111
shortuuid
1212
tiktoken
13-
torch
14-
transformers>=4.33.0
13+
torch<=2.1.2,>=2.0.0
14+
transformers>=4.33.0,<=4.37.1
15+
triton>=2.1.0,<2.2.0
1516
uvicorn

0 commit comments

Comments (0)