8 changes: 5 additions & 3 deletions docs/design/cuda_graphs.md
@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and current
```python
class BatchDescriptor(NamedTuple):
num_tokens: int
uniform_decode: bool = False
num_reqs: int
uniform: bool = False
has_lora: bool = False
```

where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
where `num_tokens` can be the padded token length, and `uniform` indicates whether all the requests in the batch have the same query length. Many attention backends only support full cudagraphs when the batch is uniform. Pure decode batches are uniform but do not necessarily have query length 1 (i.e. `num_tokens == num_reqs`); this occurs in the validation pass of spec-decode, where "decode" batches have a query length of `1+num_spec_tokens`.

The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item.

!!! note
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
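
A minimal illustrative sketch (not part of this diff) of how some common batches map onto these fields; the concrete numbers, including `num_spec_tokens = 2`, are assumptions for the example:

```python
from vllm.forward_context import BatchDescriptor

# Pure decode batch: 8 requests, each with query length 1,
# so the batch is uniform and num_tokens == num_reqs.
decode_batch = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)

# Spec-decode validation batch: 8 requests, each with query length
# 1 + num_spec_tokens = 3, so the batch is still uniform but
# num_tokens (24) != num_reqs (8).
validation_batch = BatchDescriptor(num_tokens=24, num_reqs=8, uniform=True)

# A mixed prefill/decode batch is not uniform.
mixed_batch = BatchDescriptor(num_tokens=24, num_reqs=3, uniform=False)
```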
43 changes: 31 additions & 12 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -42,12 +42,24 @@ def _create_vllm_config(
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()

# mimic VllmConfig.__post_init__
if compilation_config.cudagraph_capture_sizes:
compilation_config.max_cudagraph_capture_size = (
compilation_config.cudagraph_capture_sizes[-1]
)

compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)

return mock_config


@@ -109,9 +121,11 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform_decode=False,
uniform=False,
)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False
)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
@@ -122,32 +136,37 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
assert rt_mode == CUDAGraphMode.NONE

# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE

# 3. No key match
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_no_match)
rt_mode, key = dispatcher.dispatch(
num_tokens=15, uniform_decode=False, has_lora=False
)
assert rt_mode == CUDAGraphMode.NONE
assert key is None
assert key == BatchDescriptor(num_tokens=15)

# 4. Cascade attention should have a fall back mode
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.non_uniform
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE

18 changes: 11 additions & 7 deletions vllm/forward_context.py
@@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
"""

num_tokens: int
uniform_decode: bool = False
num_reqs: int | None = None
"""
False can also be used for an uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
the cudagraphs can handle any number of requests.
"""
uniform: bool = False
"""
True if all the requests in the batch have the same number of tokens.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""

@property
def non_uniform(self) -> "BatchDescriptor":
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
Return a relaxed version of the current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
)
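
For reference, a small sketch (not part of this diff) of how the relaxation behaves, matching the expectations in the updated dispatcher test:

```python
from vllm.forward_context import BatchDescriptor

desc = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True, has_lora=False)
relaxed = desc.relax_for_mixed_batch_cudagraphs()

# num_reqs and uniform are dropped so the key can also match a cudagraph
# captured for mixed (non-uniform) batches with the same padded token count;
# has_lora is kept because the dispatcher specializes keys on LoRA activation.
assert relaxed == BatchDescriptor(
    num_tokens=8, num_reqs=None, uniform=False, has_lora=False
)
```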


21 changes: 1 addition & 20 deletions vllm/v1/attention/backends/flashinfer.py
@@ -930,31 +930,12 @@ def build(

if num_decodes > 0:
pure_decode = num_prefills == 0
# possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decode_tokens <= self._decode_cudagraph_max_bs
)
if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_decode_tokens
)
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[
1 + num_decodes : 1 + num_input_tokens
].fill_(paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
1
)

else:
num_input_tokens = num_decode_tokens
num_input_tokens = num_decode_tokens

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -107,6 +107,8 @@ def _compute_prefix_caching_block_indices(
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)

return (
block_idx_last_computed_token,
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -72,6 +72,7 @@ class CommonAttentionMetadata:

num_reqs: int
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
Expand Down Expand Up @@ -857,7 +858,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
# 0-query len indicates a padded request; leave this at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
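
As an aside (not part of this diff), the effect of the updated mask on a toy batch, assuming `decode_threshold = 1`: the zero-length padded request is classified with the prefills, so it stays at the back of the batch.

```python
import torch

# Two decode requests, one zero-length padded request, one prefill.
query_lens = torch.tensor([1, 1, 0, 4])
decode_threshold = 1

is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
print(is_prefill)  # tensor([False, False,  True,  True])
```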
97 changes: 66 additions & 31 deletions vllm/v1/cudagraph_dispatcher.py
@@ -4,6 +4,9 @@

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger

logger = init_logger(__name__)


class CudagraphDispatcher:
@@ -28,33 +28,54 @@ class CudagraphDispatcher:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.uniform_decode_query_len = (
1
if not self.vllm_config.speculative_config
else 1 + self.vllm_config.speculative_config.num_speculative_tokens
)

# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
CUDAGraphMode.PIECEWISE: set(),
CUDAGraphMode.FULL: set(),
}

not_use_piecewise_compilation = (
not self.cudagraph_mode.requires_piecewise_compilation()
)

assert (
not_use_piecewise_compilation
not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
or self.compilation_config.is_attention_compiled_piecewise()
), (
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
"cudagraph_mode piecewise cudagraphs is used, "
"and attention should be in splitting_ops or "
"inductor splitting should be used. "
f"cudagraph_mode={self.cudagraph_mode}, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
f"compilation_mode={self.compilation_config.mode}, "
f"splitting_ops={self.compilation_config.splitting_ops}"
)

self.keys_initialized = False

def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)

if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
assert num_tokens_padded % uniform_decode_query_len == 0
else:
uniform_decode = False
num_reqs = min(num_tokens_padded, max_num_seqs)

return BatchDescriptor(
num_tokens=num_tokens_padded,
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
)

Review comment (Member):
nit: make BatchDescriptor's field also be uniform_decode to match

Reply (Collaborator Author):
I'm planning on refactoring this in a future PR, but uniform_decode implies uniform and query_len == num_speculated_tokens + 1, while the uniform property in BatchDescriptor is a bit more flexible: it implies just uniform, and the query_len could be anything. This leaves it open for uniform graphs that aren't num_speculated_tokens + 1 to potentially be used in the speculator, where we speculate with a query_len = 1 for additional tokens (granted, the graph would still probably be num_speculated_tokens + 1, since with padded speculation we'd start with that for the first eagle head).

def add_cudagraph_key(
self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
):
Expand All @@ -66,7 +90,9 @@ def add_cudagraph_key(
def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
):
# This should be called only after attention backend is initialized.
# This should be called only after the attention backend is initialized, so
# we can get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode

# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
Expand All @@ -86,9 +112,9 @@ def initialize_cudagraph_keys(
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(
num_tokens=bs, uniform_decode=False, has_lora=has_lora
),
self._create_padded_batch_descriptor(
bs, False, has_lora
).relax_for_mixed_batch_cudagraphs(),
)

# if decode cudagraph mode is FULL, and we don't already have mixed
Expand All @@ -109,40 +135,49 @@ def initialize_cudagraph_keys(
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
self._create_padded_batch_descriptor(bs, True, has_lora),
)

self.keys_initialized = True

def dispatch(
self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
Review comment on lines 143 to +149 (Member):
It seems like a step backward to change from taking BatchDescriptor to individual params - why did you choose this?

Reply (@LucasWilkinson, Collaborator Author, Nov 25, 2025):
The idea is to move to a regime where BatchDescriptor describes the exact workload captured in the cuda-graph. Since some cudagraphs (PIECEWISE, for example) can match different batches, the idea is to pass in all the parameters needed to find a matching cudagraph, which may differ from the params that describe the graph (BatchDescriptor). I'd be open to renaming this CUDAGraphDescriptor; the only edge case is when no cudagraph mode is used (eager), where we still return a batch descriptor that matches the current workload (basically implicitly indicating no padding).

Does that make sense?
"""
Given the batch properties (e.g., number of tokens and whether cascade
attention is used), dispatch to a cudagraph runtime mode and the valid
batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
return CUDAGraphMode.NONE, None
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

non_uniform_key = batch_descriptor.non_uniform
# if a batch uses cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc

# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc

# also check if non-uniform key exists for more "general"
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
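
To summarize the lookup order that the reworked `dispatch()` implements, here is a standalone sketch (not part of this diff; cascade-attention handling omitted) where plain sets and strings stand in for the dispatcher's `cudagraph_keys` and `CUDAGraphMode`:

```python
from vllm.forward_context import BatchDescriptor

# Stand-ins for dispatcher.cudagraph_keys[FULL] / [PIECEWISE].
full_keys = {BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)}
piecewise_keys = {BatchDescriptor(num_tokens=8)}

desc = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
relaxed = desc.relax_for_mixed_batch_cudagraphs()

# Lookup order: exact FULL key, relaxed FULL key, relaxed PIECEWISE key,
# otherwise no cudagraph plus a trivial descriptor (implying no padding).
if desc in full_keys:
    result = ("FULL", desc)
elif relaxed in full_keys:
    result = ("FULL", relaxed)
elif relaxed in piecewise_keys:
    result = ("PIECEWISE", relaxed)
else:
    result = ("NONE", BatchDescriptor(num_tokens=desc.num_tokens))

print(result[0])  # FULL
```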