Commit c0b9776

optimize prefill preprocess (#3869)

* optimize prefill preprocess
* fix
* fix max seqlen

In short: the per-batch sequence-length statistics (max query length, max and total kv lengths) are now computed once from host-side tensors in create_model_inputs and carried through ModelInputs and StepContext, replacing device-tensor reductions that synchronized via .item() inside the per-step preprocessing path; the 1-D position ids are likewise built with a single scatter instead of a per-sequence slice-and-concatenate.
1 parent e2fcd16 commit c0b9776

File tree

3 files changed: +55 -23 lines

3 files changed

+55
-23
lines changed

lmdeploy/pytorch/backends/cuda/op_backend.py

Lines changed: 1 addition & 1 deletion

@@ -134,7 +134,7 @@ def update_step_context(cls, step_context):
         cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
         if not step_context.is_decoding:
             kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
-            kv_flatten_size = kv_seqlens.sum().item()
+            kv_flatten_size = step_context.sum_kv_seqlen
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets,
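
Why this one-liner matters: kv_seqlens is a CUDA tensor here, and .sum().item() forces a device-to-host synchronization on every prefill step. A minimal standalone sketch of the difference (illustrative names and shapes, not code from the repo):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Host-side lengths, as the scheduler already knows them.
kv_seqlens_cpu = torch.randint(1, 4096, (256,))
kv_seqlens = kv_seqlens_cpu.to(device)

# Old path: a device reduction plus .item() blocks the CPU until all
# queued GPU work has finished.
kv_flatten_size = kv_seqlens.sum().item()

# New path: the scalar is computed once from host-side data and carried
# along, so the backend just reads a plain Python int.
sum_kv_seqlen = kv_seqlens_cpu.sum().item()  # stand-in for step_context.sum_kv_seqlen
kv_flatten_size = sum_kv_seqlen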

lmdeploy/pytorch/engine/engine.py

Lines changed: 8 additions & 0 deletions

@@ -738,8 +738,13 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
         if not is_decoding:
             seq_length = [len(tokens) for tokens in token_ids]
             seq_length = torch.tensor(seq_length, dtype=torch.long)
+            max_q_seqlen = seq_length.max().item()
         else:
             seq_length = torch.ones(batch_size, dtype=torch.long)
+            max_q_seqlen = 1
+        kv_seqlens = seq_length + history_lengths
+        max_kv_seqlen = kv_seqlens.max().item()
+        sum_kv_seqlen = kv_seqlens.sum().item()

         # block offsets
         block_offsets = self.scheduler.get_block_tables(messages)

@@ -759,6 +764,9 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
             block_offsets=block_offsets,
             is_decoding=is_decoding,
             num_ignored_history=num_ignored_history,
+            max_q_seqlen=max_q_seqlen,
+            max_kv_seqlen=max_kv_seqlen,
+            sum_kv_seqlen=sum_kv_seqlen,
             model_metas=model_metas,
         )
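
A small sketch of what the new scalars look like (made-up values; seq_length here is a CPU tensor built from Python lists, so these .max()/.sum().item() calls are cheap host reductions rather than GPU syncs):

import torch

token_ids = [[1, 2, 3], [4, 5]]  # two prefill requests
history_lengths = torch.tensor([10, 7], dtype=torch.long)

seq_length = torch.tensor([len(t) for t in token_ids], dtype=torch.long)
max_q_seqlen = seq_length.max().item()      # 3

kv_seqlens = seq_length + history_lengths   # per-sequence kv cache length
max_kv_seqlen = kv_seqlens.max().item()     # 13
sum_kv_seqlen = kv_seqlens.sum().item()     # 25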

lmdeploy/pytorch/model_inputs.py

Lines changed: 46 additions & 22 deletions

@@ -131,6 +131,9 @@ class ModelInputs:
     block_offsets: torch.LongTensor
     is_decoding: bool
     num_ignored_history: torch.LongTensor
+    max_q_seqlen: int
+    max_kv_seqlen: int
+    sum_kv_seqlen: int
     local_adapter_ids: torch.LongTensor = None
     vision_inputs: VisionModelInputs = None
     cross_length: torch.LongTensor = None

@@ -143,6 +146,8 @@ def update(self, input_ids: torch.LongTensor):
         """Update input ids."""
         assert self.is_decoding
         self.history_lengths = self.history_lengths + 1
+        self.max_kv_seqlen += 1
+        self.sum_kv_seqlen += self.seq_length.numel()
         if input_ids.dim() == 1:
             input_ids = input_ids[None, :]
         self.input_ids = input_ids
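
update() runs once per decoding step, where every live sequence gains exactly one token, so the metadata can be maintained incrementally instead of being re-reduced from tensors. A toy illustration with standalone numbers (not the real class):

batch_size = 4      # == self.seq_length.numel() during decoding
max_kv_seqlen = 13
sum_kv_seqlen = 40

# One decode step later: each sequence grew by one token.
max_kv_seqlen += 1
sum_kv_seqlen += batch_size
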
@@ -214,6 +219,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
         max_seq_len = self.seq_length[0].item()
         ret = []
         start = 0
+        max_kv_seqlen = self.max_kv_seqlen

         # for mllama
         history_cross_length = self.history_cross_length

@@ -227,13 +233,17 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
             vision_inputs = None
         end = min(max_seq_len, start + split_size)

+        max_q_seqlen = end - start
         inp = ModelInputs(
             input_ids=self.input_ids[:, start:end],
             seq_length=input_ids.new_tensor([end - start]),
             block_offsets=self.block_offsets,
             history_lengths=self.history_lengths + start,
             is_decoding=self.is_decoding,
             num_ignored_history=self.num_ignored_history,
+            max_q_seqlen=max_q_seqlen,
+            max_kv_seqlen=max_kv_seqlen,
+            sum_kv_seqlen=max_kv_seqlen,
             local_adapter_ids=self.local_adapter_ids,
             vision_inputs=vision_inputs,
             model_metas=self.model_metas,

@@ -242,6 +252,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
         )
         ret.append(inp)
         history_cross_length = cross_length
+        max_kv_seqlen += max_q_seqlen

         start = end
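
Each chunk produced by this split wraps a single sequence (its seq_length is the one-element tensor [end - start]), so the sum of kv lengths over the chunk's batch equals their max; that is why the constructor can pass sum_kv_seqlen=max_kv_seqlen. A tiny check of the invariant (illustrative values):

import torch

seq_length = torch.tensor([4])      # end - start for one chunk
history = torch.tensor([9])         # tokens already in the cache
kv_seqlens = seq_length + history

assert kv_seqlens.max().item() == kv_seqlens.sum().item() == 13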

@@ -291,6 +302,9 @@ def make_dummy(cls,
             block_offsets=block_offsets,
             is_decoding=is_decoding,
             num_ignored_history=num_ignored_history,
+            max_q_seqlen=1,
+            max_kv_seqlen=1,
+            sum_kv_seqlen=batch_size,
         )

     def log_info(self):
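
Dummy inputs model batch_size one-token sequences with no history, so all three new fields are fixed by construction (sketch, not repo code):

batch_size = 8
max_q_seqlen = 1            # each dummy sequence holds a single token
max_kv_seqlen = 1           # no history, so the kv length is 1 as well
sum_kv_seqlen = batch_size  # batch_size sequences, each of kv length 1
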
@@ -316,6 +330,7 @@ class StepContext:
     q_start_loc: torch.LongTensor
     kv_caches: List
     is_decoding: bool
+    sum_kv_seqlen: int
     local_adapter_ids: torch.LongTensor = None
     input_embeddings: torch.Tensor = None
     input_embedding_indexing: torch.Tensor = None
@@ -348,7 +363,6 @@ def new(
         """
         q_seqlens = inputs.seq_length
         history_seqlens = inputs.history_lengths
-        device = q_seqlens.device

         input_multimodals = None
         if inputs.vision_inputs is not None:
@@ -360,16 +374,9 @@ def new(
             input_embeddings, input_embedding_indexing = \
                 inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens)

-        # kv_seqlens
-        if inputs.is_decoding:
-            attention_mask = torch.ones_like(q_seqlens)[:, None]
-            position_ids = history_seqlens.unsqueeze(-1).clone()
-        else:
-            max_q_seqlen = q_seqlens.max().item()
-            mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
-            attention_mask = (mask_range < q_seqlens[:, None]).long()
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids += history_seqlens.unsqueeze(-1)
+        # position ids
+        attention_mask, position_ids = cls.get_mask_and_position_ids(inputs)
+        position_ids = position_ids[None]  # [num_tokens] -> [1, num_tokens]
         q_start_loc = q_seqlens.cumsum(0) - q_seqlens

         # cross
@@ -378,8 +385,6 @@ def new(
         if inputs.cross_length is not None:
             cross_kv_seqlens = (inputs.cross_length + inputs.history_cross_length)

-        # position ids 1d
-        position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None]
         # seq_len + history_length
         kv_seqlens = q_seqlens + history_seqlens
         kv_seqlens -= inputs.num_ignored_history
@@ -398,6 +403,7 @@ def new(
             q_start_loc=q_start_loc,
             kv_caches=kv_caches,
             is_decoding=inputs.is_decoding,
+            sum_kv_seqlen=inputs.sum_kv_seqlen,
             local_adapter_ids=inputs.local_adapter_ids,
             vision_inputs=inputs.vision_inputs,
             kv_quant_policy=kv_quant_policy,
@@ -412,15 +418,33 @@ def new(
         return ret

     @classmethod
-    def get_position_ids_1d(cls, position_ids: torch.LongTensor, seq_length: torch.LongTensor):
-        """Get 1d position_ids."""
-        if position_ids.size(0) == 1 or position_ids.size(1) == 1:
-            position_ids_1d = position_ids.flatten()
-        else:
-            device = position_ids.device
-            position_ids_1d = [ids[:l] for ids, l in zip(position_ids.cpu(), seq_length.cpu())]
-            position_ids_1d = torch.cat(position_ids_1d).to(device)
-        return position_ids_1d
+    def get_mask_and_position_ids(cls, inputs: ModelInputs):
+        """Get position ids."""
+        q_seqlens = inputs.seq_length
+        history_seqlens = inputs.history_lengths
+
+        # decoding
+        if inputs.is_decoding:
+            attention_mask = torch.ones_like(q_seqlens)[:, None]
+            position_ids = history_seqlens.unsqueeze(-1).clone()
+            position_ids = position_ids.flatten()
+            return attention_mask, position_ids
+
+        num_tokens = inputs.input_ids.numel()
+        max_q_seqlen = inputs.max_q_seqlen
+        device = q_seqlens.device
+
+        # get mask
+        mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
+        attention_mask = (mask_range < q_seqlens[:, None]).long()
+
+        # position_ids
+        indices = attention_mask.long().cumsum(-1) - 1
+        position_ids = indices + history_seqlens.unsqueeze(-1)
+        indices[1:] += q_seqlens.cumsum(0)[:-1, None]
+        position_ids_1d = position_ids.new_empty(num_tokens)
+        position_ids_1d[indices.flatten()] = position_ids.flatten()
+        return attention_mask, position_ids_1d


class StepContextManager:
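
The rewritten helper replaces the old get_position_ids_1d, which sliced each row of the padded position matrix on the CPU and concatenated the pieces (a host round-trip per prefill). The new code instead scatters the padded [batch, max_q_seqlen] matrix into a flat [num_tokens] tensor with a single indexing op. A standalone sketch of the scatter trick, with names mirroring the diff and a made-up batch of three sequences:

import torch

q_seqlens = torch.tensor([2, 3, 4])          # query lengths per sequence
history_seqlens = torch.tensor([10, 0, 5])   # cached tokens per sequence
max_q_seqlen = 4                             # inputs.max_q_seqlen
num_tokens = int(q_seqlens.sum())            # inputs.input_ids.numel() == 9

# Padded mask: row i has q_seqlens[i] ones followed by zeros.
mask_range = torch.arange(max_q_seqlen)[None, :]
attention_mask = (mask_range < q_seqlens[:, None]).long()

indices = attention_mask.cumsum(-1) - 1             # local offsets per row
position_ids = indices + history_seqlens[:, None]   # padded position matrix
indices[1:] += q_seqlens.cumsum(0)[:-1, None]       # shift rows to their flat start

# Scatter the padded matrix into a flat tensor; padded slots merely rewrite
# the last valid value of their row, so no garbage survives.
position_ids_1d = position_ids.new_empty(num_tokens)
position_ids_1d[indices.flatten()] = position_ids.flatten()

# seq0 occupies positions 10..11, seq1 0..2, seq2 5..8:
assert position_ids_1d.tolist() == [10, 11, 0, 1, 2, 5, 6, 7, 8]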
