Skip to content

Commit a7a04ba

Browse files
cleanup
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent a55ac68 commit a7a04ba

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1379,10 +1379,10 @@ def _prepare_inputs(
13791379
def _build_attention_metadata(
13801380
self,
13811381
num_tokens: int,
1382-
num_tokens_padded: int,
13831382
num_reqs: int,
1384-
num_reqs_padded: int,
13851383
max_query_len: int,
1384+
num_tokens_padded: int | None = None,
1385+
num_reqs_padded: int | None = None,
13861386
ubatch_slices: UBatchSlices | None = None,
13871387
logits_indices: torch.Tensor | None = None,
13881388
use_spec_decode: bool = False,
@@ -1393,6 +1393,9 @@ def _build_attention_metadata(
13931393
"""
13941394
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
13951395
"""
1396+
num_tokens_padded = num_tokens_padded or num_tokens
1397+
num_reqs_padded = num_reqs_padded or num_reqs
1398+
13961399
logits_indices_padded = None
13971400
num_logits_indices = None
13981401
if logits_indices is not None:
@@ -2791,15 +2794,14 @@ def execute_model(
27912794
)
27922795

27932796
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
2794-
pad_attention = cudagraph_mode == CUDAGraphMode.FULL
2795-
attn_metadata, spec_decode_common_attn_metadata = (
2797+
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
2798+
2799+
(attn_metadata, spec_decode_common_attn_metadata) = (
27962800
self._build_attention_metadata(
27972801
num_tokens=num_tokens_unpadded,
2798-
num_tokens_padded=num_tokens_padded
2799-
if pad_attention
2800-
else num_tokens_unpadded,
2802+
num_tokens_padded=num_tokens_padded if pad_attn else None,
28012803
num_reqs=num_reqs,
2802-
num_reqs_padded=num_reqs_padded if pad_attention else num_reqs,
2804+
num_reqs_padded=num_reqs_padded if pad_attn else None,
28032805
max_query_len=max_num_scheduled_tokens,
28042806
ubatch_slices=ubatch_slices,
28052807
logits_indices=logits_indices,
@@ -3817,9 +3819,7 @@ def _dummy_run(
38173819

38183820
attn_metadata, _ = self._build_attention_metadata(
38193821
num_tokens=num_tokens_unpadded,
3820-
num_tokens_padded=num_tokens_padded,
38213822
num_reqs=num_reqs_padded,
3822-
num_reqs_padded=num_reqs_padded,
38233823
max_query_len=max_query_len,
38243824
ubatch_slices=ubatch_slices,
38253825
for_cudagraph_capture=True,

0 commit comments

Comments
 (0)