Commit 831a83a

fix batched prefill (#3887)
1 parent b55f80c commit 831a83a

1 file changed, +1 -1 lines changed

lmdeploy/pytorch/model_inputs.py

Lines changed: 1 addition & 1 deletion
@@ -443,7 +443,7 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs):
         # position_ids
         indices = attention_mask.long().cumsum(-1) - 1
         position_ids = indices + history_seqlens.unsqueeze(-1)
-        indices[1:] += q_seqlens[:-1, None]
+        indices[1:] += q_seqlens.cumsum(0)[:-1, None]
         position_ids_1d = position_ids.new_empty(num_tokens)
         position_ids_1d[indices.flatten()] = position_ids.flatten()
         return attention_mask, position_ids_1d
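
The one-line change above affects where each row of the padded batch lands in the flattened `position_ids_1d` buffer. A minimal pure-Python sketch of that offset arithmetic follows. It is not lmdeploy code: the helper names and the example lengths are hypothetical, padding is ignored, and plain lists stand in for tensors. It suggests that the old expression shifts row `i` by the length of row `i - 1` only, while the fixed expression shifts it by the total length of all earlier rows.

```python
from itertools import accumulate


def row_offsets_before_fix(q_seqlens):
    """Mirror of `q_seqlens[:-1]`: row i is shifted by the length of row i - 1 only."""
    return [0] + list(q_seqlens[:-1])


def row_offsets_after_fix(q_seqlens):
    """Mirror of `q_seqlens.cumsum(0)[:-1]`: row i is shifted by the sum of rows 0..i-1."""
    return [0] + list(accumulate(q_seqlens))[:-1]


def flat_indices(q_seqlens, offsets):
    """Flat slot assigned to every valid token, row by row (padding ignored)."""
    return [off + t for n, off in zip(q_seqlens, offsets) for t in range(n)]


# Hypothetical batch of three prefill sequences with 2, 3, and 4 query tokens.
q_seqlens = [2, 3, 4]

before = flat_indices(q_seqlens, row_offsets_before_fix(q_seqlens))
after = flat_indices(q_seqlens, row_offsets_after_fix(q_seqlens))

print(before)  # slots 3 and 4 are assigned twice, slots 7 and 8 never
print(after)   # every slot 0..8 is assigned exactly once
```

With only two sequences the two expressions give the same offsets, so under these assumptions the overlap would only appear for batches of three or more sequences.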
