`docs/design/cuda_graphs.md` (5 additions, 3 deletions)
```python
class BatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int
    uniform: bool = False
    has_lora: bool = False
```
where `num_tokens` can be the padded token length, and `uniform` indicates whether all requests in the batch have the same query length. Many attention backends only support full CUDA Graphs when batches are uniform. Pure decode batches are uniform but do not necessarily have query length 1 (i.e., `num_tokens == num_reqs`): in the validation pass of speculative decoding, "decode" batches have a query length of `1 + num_spec_tokens`.
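As an illustrative sketch of how these fields relate (the `describe_batch` helper and its padding policy are hypothetical, not vLLM code; it reuses the `BatchDescriptor` defined above):

```python
# Illustrative only: a hypothetical helper that classifies a batch by its
# per-request query lengths, reusing the BatchDescriptor defined above.
def describe_batch(query_lens: list[int], padded_num_tokens: int,
                   has_lora: bool = False) -> BatchDescriptor:
    # A batch is uniform when every request has the same query length.
    uniform = len(set(query_lens)) == 1
    return BatchDescriptor(
        num_tokens=padded_num_tokens,
        num_reqs=len(query_lens),
        uniform=uniform,
        has_lora=has_lora,
    )


# Pure decode: one token per request, so num_tokens == num_reqs.
print(describe_batch([1, 1, 1, 1], padded_num_tokens=4))
# -> BatchDescriptor(num_tokens=4, num_reqs=4, uniform=True, has_lora=False)

# Spec-decode validation with num_spec_tokens = 2: still uniform, but each
# request has query length 1 + num_spec_tokens = 3, so num_tokens != num_reqs.
print(describe_batch([3, 3, 3, 3], padded_num_tokens=12))
# -> BatchDescriptor(num_tokens=12, num_reqs=4, uniform=True, has_lora=False)

# Mixed prefill/decode: differing query lengths, so the batch is not uniform.
print(describe_batch([5, 1, 1], padded_num_tokens=8))
# -> BatchDescriptor(num_tokens=8, num_reqs=3, uniform=False, has_lora=False)
```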
The goal of this structure is to uniquely identify a (padded) batch with the fewest possible fields, so that each descriptor corresponds to exactly one CUDA Graphs item.
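Because `BatchDescriptor` is a `NamedTuple`, it is hashable and compares by value, so a descriptor can serve directly as a cache key. The following sketch is illustrative only (the `_graph_pool` dict, the placeholder capture, and `get_or_capture` are hypothetical, not vLLM's actual dispatch code):

```python
# Illustrative only: a hypothetical graph cache, not vLLM's dispatcher.
# In a real runner the cached value would be a captured torch.cuda.CUDAGraph
# together with its static input/output buffers.
_graph_pool: dict[BatchDescriptor, str] = {}


def get_or_capture(desc: BatchDescriptor) -> str:
    # Equal descriptors hash equally, so any batch that pads to the same
    # (num_tokens, num_reqs, uniform, has_lora) replays one captured graph.
    if desc not in _graph_pool:
        _graph_pool[desc] = f"captured-graph-for-{desc}"  # placeholder capture
    return _graph_pool[desc]


d1 = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
d2 = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
assert get_or_capture(d1) is get_or_capture(d2)  # captured once, then reused
```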
!!! note
    The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., by including more items such as `uniform_query_len` to support multiple different uniform decode length settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token-length aware (for example, some multi-modal inputs).
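As a sketch of the kind of extension the note describes (hypothetical; the linked PR contains the actual proposal), such a field could simply become another tuple member:

```python
from typing import NamedTuple, Optional


# Hypothetical extension sketched from the note above; see the linked PR
# for the actual proposal.
class ExtendedBatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int
    uniform: bool = False
    has_lora: bool = False
    # Common query length of a uniform batch, e.g. 1 for pure decode or
    # 1 + num_spec_tokens for spec-decode validation; None when non-uniform.
    uniform_query_len: Optional[int] = None
```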