Commit a32fd83

LucasWilkinson authored and kitaekatt committed
[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (vllm-project#28579)
1 parent a1ab702 commit a32fd83

File tree: 10 files changed, +401 -283 lines changed

docs/design/cuda_graphs.md

Lines changed: 5 additions & 3 deletions

````diff
@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
 ```python
 class BatchDescriptor(NamedTuple):
     num_tokens: int
-    uniform_decode: bool = False
+    num_reqs: int
+    uniform: bool = False
+    has_lora: bool = False
 ```

-where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
+where `num_tokens` can be the padded token length, and `uniform` indicates whether all the requests have the same query length. Many attention backends only support full cudagraphs when the batch is uniform; pure decode batches are uniform but may not have query length 1 (i.e. `num_tokens == num_reqs`); this occurs in the validation pass of spec-decode, where "decode" batches have a query length of `1+num_spec_tokens`.

-The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
+The goal of this structure is to uniquely identify a (padded) batch with the minimal possible set of items corresponding to a CUDA Graphs item.

 !!! note
     The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
````
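For intuition, here is a minimal sketch (not part of the diff) of why a uniform batch is not necessarily a query-length-1 batch, assuming `num_spec_tokens = 2`:

```python
# Illustrative numbers only: a pure-decode batch in the spec-decode
# validation pass is uniform, but every request has query length
# 1 + num_spec_tokens, so num_tokens != num_reqs.
num_spec_tokens = 2
num_reqs = 8
num_tokens = num_reqs * (1 + num_spec_tokens)

print(num_tokens, num_reqs)    # 24 8
print(num_tokens == num_reqs)  # False, yet the batch is still uniform
```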

tests/v1/cudagraph/test_cudagraph_dispatch.py

Lines changed: 31 additions & 12 deletions

```diff
@@ -42,12 +42,24 @@ def _create_vllm_config(
     mock_config.compilation_config = compilation_config
     mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
     mock_config.parallel_config = ParallelConfig()
+    mock_config.speculative_config = None  # No speculative decoding
     if not lora_config:
         mock_config.lora_config = None
     # Mimic the behavior of VllmConfig.__post_init__()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         compilation_config.set_splitting_ops_for_v1()

+    # mimic VllmConfig.__post_init__
+    if compilation_config.cudagraph_capture_sizes:
+        compilation_config.max_cudagraph_capture_size = (
+            compilation_config.cudagraph_capture_sizes[-1]
+        )
+
+    compilation_config.post_init_cudagraph_sizes()
+    mock_config.pad_for_cudagraph = (
+        lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
+    )
+
     return mock_config


@@ -109,9 +121,11 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
         # 1. non-uniform batch, size in cudagraph size list
         desc_full_exact = BatchDescriptor(
             num_tokens=8,
-            uniform_decode=False,
+            uniform=False,
+        )
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=8, uniform_decode=False, has_lora=False
         )
-        rt_mode, key = dispatcher.dispatch(desc_full_exact)
         if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_full_exact
@@ -122,32 +136,37 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
             assert rt_mode == CUDAGraphMode.NONE

         # 2. uniform decode batch, size in cudagraph size list
-        desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
-        rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
+        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=8, uniform_decode=True, has_lora=False
+        )
         if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
-            assert key == desc_uniform_exact.non_uniform
+            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
         elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact
         elif cudagraph_mode_str == "PIECEWISE":
             assert rt_mode == CUDAGraphMode.PIECEWISE
-            assert key == desc_uniform_exact.non_uniform
+            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
         else:
             assert rt_mode == CUDAGraphMode.NONE

         # 3. No key match
-        desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
-        rt_mode, key = dispatcher.dispatch(desc_no_match)
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=15, uniform_decode=False, has_lora=False
+        )
         assert rt_mode == CUDAGraphMode.NONE
-        assert key is None
+        assert key == BatchDescriptor(num_tokens=15)

         # 4. Cascade attention should have a fall back mode
-        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
-        rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
+        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
+        )
         if "PIECEWISE" in cudagraph_mode_str:  # string contains check
             assert rt_mode == CUDAGraphMode.PIECEWISE
-            assert key == desc_full_exact.non_uniform
+            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
         else:
             assert rt_mode == CUDAGraphMode.NONE
```
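The `pad_for_cudagraph` lambda above looks up the padded size that `post_init_cudagraph_sizes()` precomputes in `bs_to_padded_graph_size`. Conceptually this rounds a token count up to the smallest captured cudagraph size; a minimal standalone sketch of that rounding, with a hypothetical helper and an assumed capture-size list:

```python
import bisect


# Round num_tokens up to the smallest capture size that can hold it.
# capture_sizes must be sorted ascending; helper name is illustrative only.
def pad_for_cudagraph(num_tokens: int, capture_sizes: list[int]) -> int:
    idx = bisect.bisect_left(capture_sizes, num_tokens)
    if idx == len(capture_sizes):
        raise ValueError("num_tokens exceeds the largest capture size")
    return capture_sizes[idx]


print(pad_for_cudagraph(3, [1, 2, 4, 8, 16]))  # 4
print(pad_for_cudagraph(8, [1, 2, 4, 8, 16]))  # 8
```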

vllm/forward_context.py

Lines changed: 11 additions & 7 deletions

```diff
@@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
     """

     num_tokens: int
-    uniform_decode: bool = False
+    num_reqs: int | None = None
     """
-    False can also be used for an uniform decode batch to dispatch to the
-    cudagraph supporting non-uniform batches.
+    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
+    the cudagraphs can handle any number of requests.
+    """
+    uniform: bool = False
+    """
+    True if all the requests in the batch have the same number of tokens.
     """
     has_lora: bool = False
     """
     Whether this batch has active LoRA adapters.
     """

-    @property
-    def non_uniform(self) -> "BatchDescriptor":
+    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
         """
-        Return a non-uniform version of current batch descriptor.
+        Return a relaxed version of current batch descriptor that is still compatible
+        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
         """
         return BatchDescriptor(
-            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
         )

```
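To see what the relaxation does, here is a self-contained sketch mirroring the fields and method above (standalone code, not an import from vLLM): relaxing drops `num_reqs` and `uniform` so a uniform-decode key can fall back to the key used for the more general mixed (piecewise) graphs, while `has_lora` is preserved.

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False

    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
        # Drop the request count and uniformity; keep num_tokens and has_lora.
        return BatchDescriptor(
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
        )


desc = BatchDescriptor(num_tokens=24, num_reqs=8, uniform=True, has_lora=True)
print(desc.relax_for_mixed_batch_cudagraphs())
# BatchDescriptor(num_tokens=24, num_reqs=None, uniform=False, has_lora=True)
```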

vllm/v1/attention/backends/flashinfer.py

Lines changed: 1 addition & 20 deletions

```diff
@@ -930,31 +930,12 @@ def build(

         if num_decodes > 0:
             pure_decode = num_prefills == 0
-            # possible required padding for cudagraph replay
             use_cudagraph = (
                 self.enable_cuda_graph
                 and pure_decode
                 and num_decode_tokens <= self._decode_cudagraph_max_bs
             )
-            if use_cudagraph:
-                num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                    num_decode_tokens
-                )
-                # Carefully fulfill the padding region with reasonable value
-                # on cpu.
-                # Make sure paged_kv_indptr_cpu is not decreasing
-                self.paged_kv_indptr_cpu[
-                    1 + num_decodes : 1 + num_input_tokens
-                ].fill_(paged_kv_indptr_cpu[-1])
-                # Fill the remaining paged_kv_last_page_len_cpu with 1.
-                # This is because flashinfer treats 0 as a full page
-                # instead of empty.
-                self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
-                    1
-                )
-
-            else:
-                num_input_tokens = num_decode_tokens
+            num_input_tokens = num_decode_tokens

             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph
```

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -107,6 +107,8 @@ def _compute_prefix_caching_block_indices(
     )
     # -1 in case it's non-computed and causes later issues with indexing
    block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
+    # -1 in the case we have a padded request (0 seq-len)
+    block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)

     return (
         block_idx_last_computed_token,
```
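A toy example (not from the diff) of what the extra clamp guards against: a padded request with zero scheduled tokens produces a block index of -1, and clamping keeps later indexing in bounds.

```python
import torch

# The padded request (last entry) has 0 scheduled tokens, giving index -1.
block_idx_last_scheduled_token = torch.tensor([3, 0, -1])
print(block_idx_last_scheduled_token.clamp(min=0))  # tensor([3, 0, 0])
```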

vllm/v1/attention/backends/utils.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -72,6 +72,7 @@ class CommonAttentionMetadata:

     num_reqs: int
     """Number of requests"""
+    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
     num_actual_tokens: int
     """Total number of tokens in batch"""
     max_query_len: int
@@ -857,7 +858,9 @@ def split_decodes_and_prefills(
     if require_uniform:
         is_prefill = query_lens != query_lens[0]
     else:
-        is_prefill = query_lens > decode_threshold
+        # 0-query len indicates a padded request; leave this at the back
+        # of the batch with the prefills
+        is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
```
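A small sketch of the updated mask (illustrative tensors only): zero-length padded requests are classified with the prefills so they stay at the back of the batch instead of being counted as decodes.

```python
import torch

decode_threshold = 1
query_lens = torch.tensor([1, 1, 1, 5, 0, 0])  # two trailing padded requests

is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
print(is_prefill)  # tensor([False, False, False,  True,  True,  True])
```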

vllm/v1/cudagraph_dispatcher.py

Lines changed: 66 additions & 31 deletions

```diff
@@ -4,6 +4,9 @@

 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)


 class CudagraphDispatcher:
@@ -28,33 +31,54 @@ class CudagraphDispatcher:
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        self.uniform_decode_query_len = (
+            1
+            if not self.vllm_config.speculative_config
+            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
+        )

         # Dict to store valid cudagraph dispatching keys.
         self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
             CUDAGraphMode.PIECEWISE: set(),
             CUDAGraphMode.FULL: set(),
         }

-        not_use_piecewise_compilation = (
-            not self.cudagraph_mode.requires_piecewise_compilation()
-        )
-
         assert (
-            not_use_piecewise_compilation
+            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
             or self.compilation_config.is_attention_compiled_piecewise()
         ), (
             "Compilation mode should be CompilationMode.VLLM_COMPILE when "
             "cudagraph_mode piecewise cudagraphs is used, "
             "and attention should be in splitting_ops or "
             "inductor splitting should be used. "
-            f"cudagraph_mode={self.cudagraph_mode}, "
+            f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
             f"compilation_mode={self.compilation_config.mode}, "
             f"splitting_ops={self.compilation_config.splitting_ops}"
         )

         self.keys_initialized = False

+    def _create_padded_batch_descriptor(
+        self, num_tokens: int, uniform_decode: bool, has_lora: bool
+    ) -> BatchDescriptor:
+        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+        uniform_decode_query_len = self.uniform_decode_query_len
+        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
+            num_reqs = num_tokens_padded // uniform_decode_query_len
+            assert num_tokens_padded % uniform_decode_query_len == 0
+        else:
+            uniform_decode = False
+            num_reqs = min(num_tokens_padded, max_num_seqs)
+
+        return BatchDescriptor(
+            num_tokens=num_tokens_padded,
+            num_reqs=num_reqs,
+            uniform=uniform_decode,
+            has_lora=has_lora,
+        )
+
     def add_cudagraph_key(
         self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
     ):
@@ -66,7 +90,9 @@ def add_cudagraph_key(
     def initialize_cudagraph_keys(
         self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
     ):
-        # This should be called only after attention backend is initialized.
+        # This should be called only after attention backend is initialized. So we can
+        # get the correct cudagraph mode after backend support is resolved.
+        self.cudagraph_mode = cudagraph_mode

         # LoRA activation cases to specialize the cuda graphs on
         if self.vllm_config.lora_config:
@@ -86,9 +112,9 @@ def initialize_cudagraph_keys(
         ):
             self.add_cudagraph_key(
                 cudagraph_mode.mixed_mode(),
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=False, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(
+                    bs, False, has_lora
+                ).relax_for_mixed_batch_cudagraphs(),
             )

         # if decode cudagraph mode is FULL, and we don't already have mixed
@@ -109,40 +135,49 @@ def initialize_cudagraph_keys(
         for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
             self.add_cudagraph_key(
                 CUDAGraphMode.FULL,
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=True, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(bs, True, has_lora),
             )
+
         self.keys_initialized = True

     def dispatch(
-        self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
-    ) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
+        self,
+        num_tokens: int,
+        uniform_decode: bool,
+        has_lora: bool,
+        use_cascade_attn: bool = False,
+    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
         """
         Given conditions(e.g.,batch descriptor and if using cascade attention),
         dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
         """
-        # if not initialized, just skip dispatching.
-        if not self.keys_initialized:
-            return CUDAGraphMode.NONE, None
+        if (
+            not self.keys_initialized
+            or self.cudagraph_mode == CUDAGraphMode.NONE
+            or num_tokens > self.compilation_config.max_cudagraph_capture_size
+        ):
+            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
+
+        batch_desc = self._create_padded_batch_descriptor(
+            num_tokens, uniform_decode, has_lora
+        )
+        relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

-        non_uniform_key = batch_descriptor.non_uniform
-        # if a batch use cascade attention, bypass checking full cudagraphs
         if not use_cascade_attn:
             # check if key exists for full cudagraph
-            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, batch_descriptor
+            if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, batch_desc

-            # otherwise, check if non-uniform key exists
-            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, non_uniform_key
+            # otherwise, check if the relaxed key exists
+            if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, relaxed_batch_desc

-        # also check if non-uniform key exists for more "general"
+        # also check if the relaxed key exists for more "general"
         # piecewise cudagraph
-        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
-            return CUDAGraphMode.PIECEWISE, non_uniform_key
+        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
+            return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

-        # finally, just return no cudagraphs
-        return CUDAGraphMode.NONE, None
+        # finally, just return no cudagraphs and a trivial batch descriptor
+        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
```
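A worked example (numbers assumed, not from the diff) of what `_create_padded_batch_descriptor` produces for a uniform spec-decode batch, where padding is assumed to round 21 tokens up to a 24-token capture size:

```python
# Assumed setup: num_spec_tokens = 2, so each uniform "decode" request
# contributes 1 + 2 = 3 tokens; pad_for_cudagraph(21) is assumed to be 24.
uniform_decode_query_len = 1 + 2
num_tokens_padded = 24

assert num_tokens_padded % uniform_decode_query_len == 0
num_reqs = num_tokens_padded // uniform_decode_query_len
print(num_reqs)  # 8 -> BatchDescriptor(num_tokens=24, num_reqs=8, uniform=True)
```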
