@@ -1391,10 +1391,10 @@ def _prepare_inputs(
13911391 def _build_attention_metadata (
13921392 self ,
13931393 num_tokens : int ,
1394- num_tokens_padded : int ,
13951394 num_reqs : int ,
1396- num_reqs_padded : int ,
13971395 max_query_len : int ,
1396+ num_tokens_padded : int | None = None ,
1397+ num_reqs_padded : int | None = None ,
13981398 ubatch_slices : UBatchSlices | None = None ,
13991399 logits_indices : torch .Tensor | None = None ,
14001400 use_spec_decode : bool = False ,
@@ -1405,6 +1405,9 @@ def _build_attention_metadata(
14051405 """
14061406 :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
14071407 """
1408+ num_tokens_padded = num_tokens_padded or num_tokens
1409+ num_reqs_padded = num_reqs_padded or num_reqs
1410+
14081411 logits_indices_padded = None
14091412 num_logits_indices = None
14101413 if logits_indices is not None :
@@ -2789,15 +2792,14 @@ def execute_model(
27892792 )
27902793
27912794 use_spec_decode = len (scheduler_output .scheduled_spec_decode_tokens ) > 0
2792- pad_attention = cudagraph_mode == CUDAGraphMode .FULL
2793- attn_metadata , spec_decode_common_attn_metadata = (
2795+ pad_attn = cudagraph_mode == CUDAGraphMode .FULL
2796+
2797+ (attn_metadata , spec_decode_common_attn_metadata ) = (
27942798 self ._build_attention_metadata (
27952799 num_tokens = num_tokens_unpadded ,
2796- num_tokens_padded = num_tokens_padded
2797- if pad_attention
2798- else num_tokens_unpadded ,
2800+ num_tokens_padded = num_tokens_padded if pad_attn else None ,
27992801 num_reqs = num_reqs ,
2800- num_reqs_padded = num_reqs_padded if pad_attention else num_reqs ,
2802+ num_reqs_padded = num_reqs_padded if pad_attn else None ,
28012803 max_query_len = max_num_scheduled_tokens ,
28022804 ubatch_slices = ubatch_slices ,
28032805 logits_indices = logits_indices ,
@@ -3825,9 +3827,7 @@ def _dummy_run(
38253827
38263828 attn_metadata , _ = self ._build_attention_metadata (
38273829 num_tokens = num_tokens_unpadded ,
3828- num_tokens_padded = num_tokens_padded ,
38293830 num_reqs = num_reqs_padded ,
3830- num_reqs_padded = num_reqs_padded ,
38313831 max_query_len = max_query_len ,
38323832 ubatch_slices = ubatch_slices ,
38333833 for_cudagraph_capture = True ,
0 commit comments