@@ -131,6 +131,9 @@ class ModelInputs:
    block_offsets: torch.LongTensor
    is_decoding: bool
    num_ignored_history: torch.LongTensor
+    max_q_seqlen: int
+    max_kv_seqlen: int
+    sum_kv_seqlen: int
    local_adapter_ids: torch.LongTensor = None
    vision_inputs: VisionModelInputs = None
    cross_length: torch.LongTensor = None
@@ -143,6 +146,8 @@ def update(self, input_ids: torch.LongTensor):
        """Update input ids."""
        assert self.is_decoding
        self.history_lengths = self.history_lengths + 1
+        self.max_kv_seqlen += 1
+        self.sum_kv_seqlen += self.seq_length.numel()
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        self.input_ids = input_ids
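A minimal standalone sketch (not part of the diff) of what this decode-time bookkeeping maintains, assuming `max_kv_seqlen` and `sum_kv_seqlen` were initialized as the maximum and the sum of `history_lengths + seq_length` over the batch; the toy tensors below stand in for the corresponding `ModelInputs` fields:

```python
import torch

# toy stand-ins for ModelInputs fields during decoding (one new token per sequence)
seq_length = torch.tensor([1, 1, 1])
history_lengths = torch.tensor([7, 12, 3])

# assumed initialization: max / sum of per-sequence kv lengths
max_kv_seqlen = int((history_lengths + seq_length).max())   # 13
sum_kv_seqlen = int((history_lengths + seq_length).sum())   # 25

for _ in range(2):  # two decode steps, mirroring update()
    history_lengths = history_lengths + 1
    max_kv_seqlen += 1                   # every kv length grows by one token
    sum_kv_seqlen += seq_length.numel()  # so the sum grows by the batch size

assert max_kv_seqlen == int((history_lengths + seq_length).max())
assert sum_kv_seqlen == int((history_lengths + seq_length).sum())
```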
@@ -214,6 +219,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
        max_seq_len = self.seq_length[0].item()
        ret = []
        start = 0
+        max_kv_seqlen = self.max_kv_seqlen

        # for mllama
        history_cross_length = self.history_cross_length
@@ -227,13 +233,17 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
            vision_inputs = None
            end = min(max_seq_len, start + split_size)

+            max_q_seqlen = end - start
            inp = ModelInputs(
                input_ids=self.input_ids[:, start:end],
                seq_length=input_ids.new_tensor([end - start]),
                block_offsets=self.block_offsets,
                history_lengths=self.history_lengths + start,
                is_decoding=self.is_decoding,
                num_ignored_history=self.num_ignored_history,
+                max_q_seqlen=max_q_seqlen,
+                max_kv_seqlen=max_kv_seqlen,
+                sum_kv_seqlen=max_kv_seqlen,
                local_adapter_ids=self.local_adapter_ids,
                vision_inputs=vision_inputs,
                model_metas=self.model_metas,
@@ -242,6 +252,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
            )
            ret.append(inp)
            history_cross_length = cross_length
+            max_kv_seqlen += max_q_seqlen

            start = end

@@ -291,6 +302,9 @@ def make_dummy(cls,
            block_offsets=block_offsets,
            is_decoding=is_decoding,
            num_ignored_history=num_ignored_history,
+            max_q_seqlen=1,
+            max_kv_seqlen=1,
+            sum_kv_seqlen=batch_size,
        )

    def log_info(self):
@@ -316,6 +330,7 @@ class StepContext:
    q_start_loc: torch.LongTensor
    kv_caches: List
    is_decoding: bool
+    sum_kv_seqlen: int
    local_adapter_ids: torch.LongTensor = None
    input_embeddings: torch.Tensor = None
    input_embedding_indexing: torch.Tensor = None
@@ -348,7 +363,6 @@ def new(
        """
        q_seqlens = inputs.seq_length
        history_seqlens = inputs.history_lengths
-        device = q_seqlens.device

        input_multimodals = None
        if inputs.vision_inputs is not None:
@@ -360,16 +374,9 @@ def new(
            input_embeddings, input_embedding_indexing = \
                inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens)

-        # kv_seqlens
-        if inputs.is_decoding:
-            attention_mask = torch.ones_like(q_seqlens)[:, None]
-            position_ids = history_seqlens.unsqueeze(-1).clone()
-        else:
-            max_q_seqlen = q_seqlens.max().item()
-            mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
-            attention_mask = (mask_range < q_seqlens[:, None]).long()
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids += history_seqlens.unsqueeze(-1)
+        # position ids
+        attention_mask, position_ids = cls.get_mask_and_position_ids(inputs)
+        position_ids = position_ids[None]  # [num_tokens] -> [1, num_tokens]
        q_start_loc = q_seqlens.cumsum(0) - q_seqlens

        # cross
@@ -378,8 +385,6 @@ def new(
        if inputs.cross_length is not None:
            cross_kv_seqlens = (inputs.cross_length + inputs.history_cross_length)

-        # position ids 1d
-        position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None]
        # seq_len + history_length
        kv_seqlens = q_seqlens + history_seqlens
        kv_seqlens -= inputs.num_ignored_history
@@ -398,6 +403,7 @@ def new(
            q_start_loc=q_start_loc,
            kv_caches=kv_caches,
            is_decoding=inputs.is_decoding,
+            sum_kv_seqlen=inputs.sum_kv_seqlen,
            local_adapter_ids=inputs.local_adapter_ids,
            vision_inputs=inputs.vision_inputs,
            kv_quant_policy=kv_quant_policy,
@@ -412,15 +418,33 @@ def new(
        return ret

    @classmethod
-    def get_position_ids_1d(cls, position_ids: torch.LongTensor, seq_length: torch.LongTensor):
-        """Get 1d position_ids."""
-        if position_ids.size(0) == 1 or position_ids.size(1) == 1:
-            position_ids_1d = position_ids.flatten()
-        else:
-            device = position_ids.device
-            position_ids_1d = [ids[:l] for ids, l in zip(position_ids.cpu(), seq_length.cpu())]
-            position_ids_1d = torch.cat(position_ids_1d).to(device)
-        return position_ids_1d
+    def get_mask_and_position_ids(cls, inputs: ModelInputs):
+        """Get attention mask and 1D position ids."""
+        q_seqlens = inputs.seq_length
+        history_seqlens = inputs.history_lengths
+
+        # decoding
+        if inputs.is_decoding:
+            attention_mask = torch.ones_like(q_seqlens)[:, None]
+            position_ids = history_seqlens.unsqueeze(-1).clone()
+            position_ids = position_ids.flatten()
+            return attention_mask, position_ids
+
+        num_tokens = inputs.input_ids.numel()
+        max_q_seqlen = inputs.max_q_seqlen
+        device = q_seqlens.device
+
+        # get mask
+        mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
+        attention_mask = (mask_range < q_seqlens[:, None]).long()
+
+        # position_ids
+        indices = attention_mask.long().cumsum(-1) - 1
+        position_ids = indices + history_seqlens.unsqueeze(-1)
+        indices[1:] += q_seqlens.cumsum(0)[:-1, None]
+        position_ids_1d = position_ids.new_empty(num_tokens)
+        position_ids_1d[indices.flatten()] = position_ids.flatten()
+        return attention_mask, position_ids_1d


class StepContextManager:
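For reference, a standalone sketch (not part of the diff) of what the prefill branch of `get_mask_and_position_ids` computes: it builds the padded `[batch, max_q_seqlen]` mask and position grid, then scatters the valid entries into the flat `[num_tokens]` layout, checked here against a naive per-sequence loop. The toy tensors `q_seqlens` and `history_seqlens` stand in for `ModelInputs.seq_length` and `ModelInputs.history_lengths`, and the row offset is written as `q_seqlens.cumsum(0) - q_seqlens` (the per-sequence start locations), equivalent to the cumulative offset applied above.

```python
import torch

q_seqlens = torch.tensor([2, 3, 4])         # query tokens per sequence
history_seqlens = torch.tensor([10, 0, 7])  # cached tokens per sequence
max_q_seqlen = int(q_seqlens.max())
num_tokens = int(q_seqlens.sum())

# padded [batch, max_q_seqlen] mask and per-row token indices
mask_range = torch.arange(max_q_seqlen)[None, :]
attention_mask = (mask_range < q_seqlens[:, None]).long()
indices = attention_mask.cumsum(-1) - 1
position_ids = indices + history_seqlens[:, None]

# scatter the padded grid into the flat [num_tokens] layout; each row is
# offset by its start location (cumulative sum of the previous query lengths)
indices = indices + (q_seqlens.cumsum(0) - q_seqlens)[:, None]
position_ids_1d = position_ids.new_empty(num_tokens)
position_ids_1d[indices.flatten()] = position_ids.flatten()

# reference: history_i + arange(q_i), concatenated over the batch
expected = torch.cat([h + torch.arange(q)
                      for h, q in zip(history_seqlens.tolist(), q_seqlens.tolist())])
assert torch.equal(position_ids_1d, expected)
print(position_ids_1d)  # tensor([10, 11,  0,  1,  2,  7,  8,  9, 10])
```

Padded slots in each row scatter to the same flat index as that row's last real token with the same value, so the duplicate writes are harmless.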