     mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
-    _query_start_loc_to_chunk_indices_offsets)
+    compute_varlen_chunk_metadata)

 # Added by the IBM Team, 2024

@@ -225,32 +225,30 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)

-    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
-    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
-
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
-
+    cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
     # varlen has implicit batch=1
     X = X.squeeze(0)
     dt = dt.squeeze(0)
     A = A.squeeze(0)
     B = B.squeeze(0)
     C = C.squeeze(0)
     Y = torch.empty_like(X)
-    final_state = mamba_chunk_scan_combined_varlen(X,
-                                                   dt,
-                                                   A,
-                                                   B,
-                                                   C,
-                                                   chunk_size,
-                                                   D=None,
-                                                   cu_seqlens=cu_seqlens,
-                                                   seq_idx=seq_idx,
-                                                   chunk_indices=chunk_indices,
-                                                   chunk_offsets=chunk_offsets,
-                                                   out=Y)
+    final_state = mamba_chunk_scan_combined_varlen(
+        X,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y,
+        D=None,
+    )

     # just test the last in sequence
     torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
@@ -312,14 +310,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

     states = None
-    for Y_min, cu_seqlens, seq_idx, (
+    for Y_min, cu_seqlens, _token_seq_idx, (
             A, dt, X, B, C) in generate_continuous_batched_examples(
                 cases, num_examples, seqlen, last_taken, exhausted, n_heads,
                 d_head, itype):

-        chunk_indices, chunk_offsets = \
-            _query_start_loc_to_chunk_indices_offsets(
-                cu_seqlens, chunk_size, cu_seqlens[-1])
+        cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+            compute_varlen_chunk_metadata(cu_seqlens, chunk_size))

         Y = torch.empty_like(X)
         new_states = mamba_chunk_scan_combined_varlen(
@@ -329,13 +326,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             B,
             C,
             chunk_size,
+            cu_seqlens=cu_seqlens.to(torch.int32),
+            cu_chunk_seqlens=cu_chunk_seqlens,
+            last_chunk_indices=last_chunk_indices,
+            seq_idx=seq_idx_chunks,
+            out=Y,
             D=None,
-            cu_seqlens=cu_seqlens,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
             initial_states=states,
-            out=Y,
         )

         # just test the last in sequence
@@ -403,9 +400,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     device = X.device

     ## full seqlen computation
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
     Y_ref = torch.empty_like(X)
     state_ref = mamba_chunk_scan_combined_varlen(
         X,
@@ -414,13 +410,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         B,
         C,
         chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_ref,
         D=None,
-        cu_seqlens=cu_seqlens,
-        seq_idx=seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=None,
-        out=Y_ref,
     )

     ## chunked seqlen computation
@@ -431,10 +427,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         torch.cumsum(chunked_seqlens, dim=0)
     ],
                                    dim=0)
-    chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(chunked_seqlens), device=device),
-        chunked_seqlens,
-        output_size=chunked_cu_seqlens[-1]).to(torch.int32)
     chunked_input_seq_len = chunked_cu_seqlens[-1]
     X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
     dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
@@ -450,9 +442,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(C, i)  # noqa: E501
     # fmt: on

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
     Y_partial = torch.empty_like(X_chunked)
     partial_state = mamba_chunk_scan_combined_varlen(
         X_chunked,
@@ -461,13 +452,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         B_chunked,
         C_chunked,
         chunk_size,
+        cu_seqlens=chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_partial,
         D=None,
-        cu_seqlens=chunked_cu_seqlens,
-        seq_idx=chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=None,
-        out=Y_partial,
     )

     # remaining chunk
@@ -477,10 +468,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         torch.cumsum(remaining_chunked_seqlens, dim=0)
     ],
                                              dim=0)
-    remaining_chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(remaining_chunked_seqlens), device=device),
-        remaining_chunked_seqlens,
-        output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
     # fmt: off
     remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
@@ -509,11 +496,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
     assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            remaining_chunked_cu_seqlens,
-            chunk_size,
-            remaining_chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
+                                      chunk_size))

     Y_chunked = torch.empty_like(remaining_X_chunked)
     state_chunked = mamba_chunk_scan_combined_varlen(
@@ -523,13 +508,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         remaining_B_chunked,
         remaining_C_chunked,
         chunk_size,
+        cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_chunked,
         D=None,
-        cu_seqlens=remaining_chunked_cu_seqlens,
-        seq_idx=remaining_chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=partial_state,
-        out=Y_chunked,
     )
     Y = concat_batch_f(Y_partial, Y_chunked)

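Note: the diff above only shows compute_varlen_chunk_metadata at its call sites, where it takes (cu_seqlens, chunk_size) and returns (cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks). The standalone sketch below uses a hypothetical sketch_varlen_chunk_metadata (not the vLLM helper) to illustrate one plausible reading of those three outputs, assuming each sequence is chunked independently into pieces of at most chunk_size tokens and that all sequences are non-empty; the real implementation may differ in dtype, device, and boundary handling.

# Illustrative only: inferred from how the tests consume the metadata, not
# taken from vllm.v1.attention.backends.mamba2_attn.
import torch


def sketch_varlen_chunk_metadata(cu_seqlens: torch.Tensor, chunk_size: int):
    """Split each sequence of a varlen batch into chunks of at most
    chunk_size tokens, never crossing a sequence boundary (an assumption
    of this sketch)."""
    chunk_lens, seq_idx = [], []
    for i in range(len(cu_seqlens) - 1):
        seqlen = int(cu_seqlens[i + 1] - cu_seqlens[i])
        n_chunks = -(-seqlen // chunk_size)  # ceil division
        for c in range(n_chunks):
            chunk_lens.append(min(chunk_size, seqlen - c * chunk_size))
            seq_idx.append(i)
    chunk_lens = torch.tensor(chunk_lens, dtype=torch.int32)
    # one sequence id per chunk (rather than per token, as before)
    seq_idx_chunks = torch.tensor(seq_idx, dtype=torch.int32)
    # cumulative token offsets of the chunk boundaries, starting at 0
    cu_chunk_seqlens = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        torch.cumsum(chunk_lens, dim=0).to(torch.int32)
    ])
    # index of the last chunk belonging to each sequence
    last_chunk_indices = torch.tensor(
        [max(j for j, s in enumerate(seq_idx) if s == i)
         for i in range(len(cu_seqlens) - 1)],
        dtype=torch.int32)
    return cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks


# e.g. two sequences of 5 and 3 tokens with chunk_size=4:
#   cu_chunk_seqlens   -> [0, 4, 5, 8]
#   last_chunk_indices -> [1, 2]
#   seq_idx_chunks     -> [0, 0, 1]
print(sketch_varlen_chunk_metadata(torch.tensor([0, 5, 8]), 4))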