Commit 2b58a28

cleanup
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent a81c2e3 commit 2b58a28

File tree

1 file changed (+10, -10 lines)


vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 10 deletions
@@ -1391,10 +1391,10 @@ def _prepare_inputs(
     def _build_attention_metadata(
         self,
         num_tokens: int,
-        num_tokens_padded: int,
         num_reqs: int,
-        num_reqs_padded: int,
         max_query_len: int,
+        num_tokens_padded: int | None = None,
+        num_reqs_padded: int | None = None,
         ubatch_slices: UBatchSlices | None = None,
         logits_indices: torch.Tensor | None = None,
         use_spec_decode: bool = False,
@@ -1405,6 +1405,9 @@ def _build_attention_metadata(
         """
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
+        num_tokens_padded = num_tokens_padded or num_tokens
+        num_reqs_padded = num_reqs_padded or num_reqs
+
         logits_indices_padded = None
         num_logits_indices = None
         if logits_indices is not None:
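Taken together, the two hunks above make the padded sizes optional keyword arguments and default them to the unpadded values inside _build_attention_metadata. A minimal, self-contained sketch of that pattern (build_metadata below is a hypothetical stand-in for the vLLM method; only the defaulting idiom mirrors the diff):

def build_metadata(
    num_tokens: int,
    num_reqs: int,
    num_tokens_padded: int | None = None,
    num_reqs_padded: int | None = None,
) -> dict[str, int]:
    # When a padded size is not supplied, fall back to the unpadded value,
    # mirroring the two assignments added in the hunk above.
    num_tokens_padded = num_tokens_padded or num_tokens
    num_reqs_padded = num_reqs_padded or num_reqs
    return {
        "num_tokens": num_tokens,
        "num_tokens_padded": num_tokens_padded,
        "num_reqs": num_reqs,
        "num_reqs_padded": num_reqs_padded,
    }

# Callers that do not need attention padding can now omit the padded arguments:
assert build_metadata(num_tokens=17, num_reqs=3) == {
    "num_tokens": 17,
    "num_tokens_padded": 17,
    "num_reqs": 3,
    "num_reqs_padded": 3,
}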
@@ -2789,15 +2792,14 @@ def execute_model(
         )

         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
-        pad_attention = cudagraph_mode == CUDAGraphMode.FULL
-        attn_metadata, spec_decode_common_attn_metadata = (
+        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+
+        (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
-                num_tokens_padded=num_tokens_padded
-                if pad_attention
-                else num_tokens_unpadded,
+                num_tokens_padded=num_tokens_padded if pad_attn else None,
                 num_reqs=num_reqs,
-                num_reqs_padded=num_reqs_padded if pad_attention else num_reqs,
+                num_reqs_padded=num_reqs_padded if pad_attn else None,
                 max_query_len=max_num_scheduled_tokens,
                 ubatch_slices=ubatch_slices,
                 logits_indices=logits_indices,
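With the defaults in place, execute_model only forwards the padded sizes when cudagraph_mode is CUDAGraphMode.FULL; in every other mode it passes None and the method falls back to the unpadded values. A small runnable sketch of that gating, using simplified stand-ins for the enum and the builder (all names below are illustrative except pad_attn, which comes from the diff):

from enum import Enum, auto

class CUDAGraphMode(Enum):  # simplified stand-in for vLLM's CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()

def build_attention_metadata(num_tokens: int, num_tokens_padded: int | None = None) -> tuple[int, int]:
    # Same fallback idiom as the new lines in _build_attention_metadata.
    return num_tokens, (num_tokens_padded or num_tokens)

def call_site(mode: CUDAGraphMode, num_tokens_unpadded: int, num_tokens_padded: int) -> tuple[int, int]:
    # Pad attention metadata only for full CUDA graphs; otherwise pass None.
    pad_attn = mode == CUDAGraphMode.FULL
    return build_attention_metadata(
        num_tokens=num_tokens_unpadded,
        num_tokens_padded=num_tokens_padded if pad_attn else None,
    )

print(call_site(CUDAGraphMode.FULL, 17, 32))       # (17, 32): padded size forwarded
print(call_site(CUDAGraphMode.PIECEWISE, 17, 32))  # (17, 17): falls back to the unpadded count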
@@ -3825,9 +3827,7 @@ def _dummy_run(

         attn_metadata, _ = self._build_attention_metadata(
             num_tokens=num_tokens_unpadded,
-            num_tokens_padded=num_tokens_padded,
             num_reqs=num_reqs_padded,
-            num_reqs_padded=num_reqs_padded,
             max_query_len=max_query_len,
             ubatch_slices=ubatch_slices,
             for_cudagraph_capture=True,
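In _dummy_run the padded arguments are dropped outright and the new defaults take over: num_reqs is already passed the padded request count, so num_reqs_padded resolves to that same value, while num_tokens_padded now resolves to num_tokens_unpadded rather than a separately supplied padded count.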
