@@ -1379,10 +1379,10 @@ def _prepare_inputs(
13791379 def _build_attention_metadata (
13801380 self ,
13811381 num_tokens : int ,
1382- num_tokens_padded : int ,
13831382 num_reqs : int ,
1384- num_reqs_padded : int ,
13851383 max_query_len : int ,
1384+ num_tokens_padded : int | None = None ,
1385+ num_reqs_padded : int | None = None ,
13861386 ubatch_slices : UBatchSlices | None = None ,
13871387 logits_indices : torch .Tensor | None = None ,
13881388 use_spec_decode : bool = False ,
@@ -1393,6 +1393,9 @@ def _build_attention_metadata(
13931393 """
13941394 :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
13951395 """
1396+ num_tokens_padded = num_tokens_padded or num_tokens
1397+ num_reqs_padded = num_reqs_padded or num_reqs
1398+
13961399 logits_indices_padded = None
13971400 num_logits_indices = None
13981401 if logits_indices is not None :
@@ -2791,15 +2794,14 @@ def execute_model(
27912794 )
27922795
27932796 use_spec_decode = len (scheduler_output .scheduled_spec_decode_tokens ) > 0
2794- pad_attention = cudagraph_mode == CUDAGraphMode .FULL
2795- attn_metadata , spec_decode_common_attn_metadata = (
2797+ pad_attn = cudagraph_mode == CUDAGraphMode .FULL
2798+
2799+ (attn_metadata , spec_decode_common_attn_metadata ) = (
27962800 self ._build_attention_metadata (
27972801 num_tokens = num_tokens_unpadded ,
2798- num_tokens_padded = num_tokens_padded
2799- if pad_attention
2800- else num_tokens_unpadded ,
2802+ num_tokens_padded = num_tokens_padded if pad_attn else None ,
28012803 num_reqs = num_reqs ,
2802- num_reqs_padded = num_reqs_padded if pad_attention else num_reqs ,
2804+ num_reqs_padded = num_reqs_padded if pad_attn else None ,
28032805 max_query_len = max_num_scheduled_tokens ,
28042806 ubatch_slices = ubatch_slices ,
28052807 logits_indices = logits_indices ,
@@ -3817,9 +3819,7 @@ def _dummy_run(
38173819
38183820 attn_metadata , _ = self ._build_attention_metadata (
38193821 num_tokens = num_tokens_unpadded ,
3820- num_tokens_padded = num_tokens_padded ,
38213822 num_reqs = num_reqs_padded ,
3822- num_reqs_padded = num_reqs_padded ,
38233823 max_query_len = max_query_len ,
38243824 ubatch_slices = ubatch_slices ,
38253825 for_cudagraph_capture = True ,
0 commit comments