[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building #28579
```diff
@@ -4,6 +4,9 @@
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class CudagraphDispatcher:
```
```diff
@@ -28,33 +31,54 @@ class CudagraphDispatcher:
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        self.uniform_decode_query_len = (
+            1
+            if not self.vllm_config.speculative_config
+            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
+        )
 
         # Dict to store valid cudagraph dispatching keys.
         self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
             CUDAGraphMode.PIECEWISE: set(),
             CUDAGraphMode.FULL: set(),
         }
 
-        not_use_piecewise_compilation = (
-            not self.cudagraph_mode.requires_piecewise_compilation()
-        )
-
         assert (
-            not_use_piecewise_compilation
+            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
             or self.compilation_config.is_attention_compiled_piecewise()
         ), (
             "Compilation mode should be CompilationMode.VLLM_COMPILE when "
             "cudagraph_mode piecewise cudagraphs is used, "
             "and attention should be in splitting_ops or "
             "inductor splitting should be used. "
-            f"cudagraph_mode={self.cudagraph_mode}, "
+            f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
             f"compilation_mode={self.compilation_config.mode}, "
             f"splitting_ops={self.compilation_config.splitting_ops}"
         )
 
         self.keys_initialized = False
 
+    def _create_padded_batch_descriptor(
+        self, num_tokens: int, uniform_decode: bool, has_lora: bool
+    ) -> BatchDescriptor:
+        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+        uniform_decode_query_len = self.uniform_decode_query_len
+        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
+            num_reqs = num_tokens_padded // uniform_decode_query_len
+            assert num_tokens_padded % uniform_decode_query_len == 0
+        else:
+            uniform_decode = False
+            num_reqs = min(num_tokens_padded, max_num_seqs)
+
+        return BatchDescriptor(
+            num_tokens=num_tokens_padded,
+            num_reqs=num_reqs,
+            uniform=uniform_decode,
+            has_lora=has_lora,
+        )
+
     def add_cudagraph_key(
         self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
     ):
```

Review thread on the `uniform=` field:

> **Member:** nit: make `BatchDescriptor`'s field also be `uniform_decode` to match
>
> **Author (Collaborator):** I'm planning on refactoring this in a future PR but …
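To make the padding arithmetic concrete, here is a minimal, self-contained sketch of what `_create_padded_batch_descriptor` computes. `pad_to_capture_size` and `CAPTURE_SIZES` are illustrative stand-ins for `VllmConfig.pad_for_cudagraph` and the configured capture list; only the branch logic mirrors the diff.

```python
from bisect import bisect_left

# Illustrative capture sizes; the real list comes from the
# compilation config's cudagraph capture sizes.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]

def pad_to_capture_size(num_tokens: int) -> int:
    # Stand-in for VllmConfig.pad_for_cudagraph: round the token
    # count up to the nearest captured graph size.
    assert num_tokens <= CAPTURE_SIZES[-1]
    return CAPTURE_SIZES[bisect_left(CAPTURE_SIZES, num_tokens)]

def padded_shape(
    num_tokens: int,
    uniform_decode: bool,
    uniform_decode_query_len: int,
    max_num_seqs: int,
    full_graphs_available: bool,
) -> tuple[int, int, bool]:
    # Mirrors the arithmetic in _create_padded_batch_descriptor.
    num_tokens_padded = pad_to_capture_size(num_tokens)
    if uniform_decode and full_graphs_available:
        # Uniform decode: every request contributes exactly
        # uniform_decode_query_len tokens, so the padded token count
        # maps directly to a padded request count.
        assert num_tokens_padded % uniform_decode_query_len == 0
        num_reqs = num_tokens_padded // uniform_decode_query_len
    else:
        # Mixed batch: each request contributes at least one token,
        # so the request count is capped by the scheduler limit.
        uniform_decode = False
        num_reqs = min(num_tokens_padded, max_num_seqs)
    return num_tokens_padded, num_reqs, uniform_decode

# 13 uniform-decode tokens with 2-token speculative queries pad to
# 16 tokens across 8 requests; a mixed batch keeps one request per token.
print(padded_shape(13, True, 2, 256, True))   # (16, 8, True)
print(padded_shape(13, False, 2, 256, True))  # (16, 16, False)
```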
```diff
@@ -66,7 +90,9 @@ def add_cudagraph_key(
     def initialize_cudagraph_keys(
         self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
     ):
-        # This should be called only after attention backend is initialized.
+        # This should be called only after attention backend is initialized. So we can
+        # get the correct cudagraph mode after backend support is resolved.
+        self.cudagraph_mode = cudagraph_mode
 
         # LoRA activation cases to specialize the cuda graphs on
         if self.vllm_config.lora_config:
```
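The new comment encodes an ordering contract: the dispatcher is constructed early, but its effective `cudagraph_mode` is pinned only after the attention backends have initialized and possibly downgraded the requested mode. A schematic sketch of that flow, with an illustrative `resolve_mode` helper that is not vLLM's actual API:

```python
from enum import Enum

class Mode(Enum):
    # Simplified stand-in for vllm.config.CUDAGraphMode.
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def resolve_mode(requested: Mode, backend_supports_full_graphs: bool) -> Mode:
    # A backend that cannot run inside full graph capture downgrades
    # the requested mode before any dispatch keys are built.
    if requested is Mode.FULL and not backend_supports_full_graphs:
        return Mode.PIECEWISE
    return requested

# Ordering contract:
#   dispatcher = CudagraphDispatcher(config)                # no keys yet
#   ... attention backends initialize, support is resolved ...
#   dispatcher.initialize_cudagraph_keys(resolved_mode, ...)  # mode + keys pinned
assert resolve_mode(Mode.FULL, backend_supports_full_graphs=False) is Mode.PIECEWISE
```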
```diff
@@ -86,9 +112,9 @@ def initialize_cudagraph_keys(
         ):
             self.add_cudagraph_key(
                 cudagraph_mode.mixed_mode(),
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=False, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(
+                    bs, False, has_lora
+                ).relax_for_mixed_batch_cudagraphs(),
             )
 
         # if decode cudagraph mode is FULL, and we don't already have mixed
```
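For dispatch-key purposes, "relaxing" a descriptor means erasing the properties that only a uniform-decode graph is specialized on, so the key can match a more general mixed-batch graph. A hypothetical sketch of what `relax_for_mixed_batch_cudagraphs` could look like; the field names follow this diff, but exactly which fields are reset is an assumption, and the real class lives in `vllm.forward_context`:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False

    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
        # Assumed behavior: a mixed-batch graph is keyed only on the
        # padded token count (and LoRA activation), so drop uniformity
        # and the request count from the key.
        return replace(self, num_reqs=None, uniform=False)

key = BatchDescriptor(num_tokens=16, num_reqs=8, uniform=True)
assert key.relax_for_mixed_batch_cudagraphs() == BatchDescriptor(num_tokens=16)
```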
```diff
@@ -109,40 +135,49 @@ def initialize_cudagraph_keys(
         for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
             self.add_cudagraph_key(
                 CUDAGraphMode.FULL,
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=True, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(bs, True, has_lora),
             )
 
         self.keys_initialized = True
 
     def dispatch(
-        self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
-    ) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
+        self,
+        num_tokens: int,
+        uniform_decode: bool,
+        has_lora: bool,
+        use_cascade_attn: bool = False,
+    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
         """
         Given conditions (e.g., batch descriptor and if using cascade attention),
         dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
         """
         # if not initialized, just skip dispatching.
-        if not self.keys_initialized:
-            return CUDAGraphMode.NONE, None
+        if (
+            not self.keys_initialized
+            or self.cudagraph_mode == CUDAGraphMode.NONE
+            or num_tokens > self.compilation_config.max_cudagraph_capture_size
+        ):
+            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
 
+        batch_desc = self._create_padded_batch_descriptor(
+            num_tokens, uniform_decode, has_lora
+        )
+        relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
 
-        non_uniform_key = batch_descriptor.non_uniform
         # if a batch use cascade attention, bypass checking full cudagraphs
         if not use_cascade_attn:
             # check if key exists for full cudagraph
-            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, batch_descriptor
+            if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, batch_desc
 
-            # otherwise, check if non-uniform key exists
-            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, non_uniform_key
+            # otherwise, check if the relaxed key exists
+            if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, relaxed_batch_desc
 
-        # also check if non-uniform key exists for more "general"
+        # also check if the relaxed key exists for more "general"
         # piecewise cudagraph
-        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
-            return CUDAGraphMode.PIECEWISE, non_uniform_key
+        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
+            return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
 
-        # finally, just return no cudagraphs
-        return CUDAGraphMode.NONE, None
+        # finally, just return no cudagraphs and a trivial batch descriptor
+        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
```

Review thread on the new `dispatch` signature (lines +143 to +149):

> **Member:** It seems like a step backward to change from taking `BatchDescriptor` to individual params - why did you choose this?
>
> **Author (Collaborator):** The idea is to move to a regime where … Does that make sense?
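The net effect is a three-step fallback ladder: the exact (possibly uniform) key against FULL graphs, the relaxed key against FULL graphs, the relaxed key against PIECEWISE graphs, and finally no graph at all. A caller-side sketch of the new contract; it assumes an already-initialized `CudagraphDispatcher` named `dispatcher`, and the specific numbers are illustrative:

```python
from vllm.config import CUDAGraphMode

# After this PR the runner hands raw batch facts to the dispatcher
# instead of padding tokens and building a BatchDescriptor itself.
runtime_mode, batch_desc = dispatcher.dispatch(
    num_tokens=13,        # raw, unpadded token count for this step
    uniform_decode=True,  # all requests are same-length decodes
    has_lora=False,
)

if runtime_mode == CUDAGraphMode.NONE:
    # No matching graph: batch_desc is the trivial BatchDescriptor(13)
    # and the model runs eagerly.
    ...
else:
    # batch_desc.num_tokens is already padded to a captured size, so
    # attention metadata can be built for the final shape up front,
    # before CUDA graph replay.
    assert batch_desc.num_tokens >= 13
```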