diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index 66bf3b27d1f5..7baadf8ba23c 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren ```python class BatchDescriptor(NamedTuple): num_tokens: int - uniform_decode: bool = False + num_reqs: int + uniform: bool = False + has_lora: bool = False ``` -where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`. +where `num_tokens` can be the padded token length, and `uniform` indicates whether all requests in the batch have the same query length. Many attention backends only support full CUDA Graphs when the batch is uniform; pure decode batches are uniform but do not necessarily have a query length of 1 (i.e. `num_tokens == num_reqs`). This occurs in the validation pass of speculative decoding, where "decode" batches have a query length of `1+num_spec_tokens`. -The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode. +The goal of this structure is to uniquely identify a (padded) batch with the minimal set of fields that corresponds to a CUDA Graphs item. !!! note The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs). diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index bb953e5c70c8..314e7094ef97 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -42,12 +42,24 @@ def _create_vllm_config( mock_config.compilation_config = compilation_config mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) mock_config.parallel_config = ParallelConfig() + mock_config.speculative_config = None  # No speculative decoding if not lora_config: mock_config.lora_config = None # Mimic the behavior of VllmConfig.__post_init__() if compilation_config.mode == CompilationMode.VLLM_COMPILE: compilation_config.set_splitting_ops_for_v1() + # mimic VllmConfig.__post_init__ + if compilation_config.cudagraph_capture_sizes: + compilation_config.max_cudagraph_capture_size = ( + compilation_config.cudagraph_capture_sizes[-1] + ) + + compilation_config.post_init_cudagraph_sizes() + mock_config.pad_for_cudagraph = ( + lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size] + ) + return mock_config @@ -109,9 +121,11 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): # 1. 
non-uniform batch, size in cudagraph size list desc_full_exact = BatchDescriptor( num_tokens=8, - uniform_decode=False, + uniform=False, + ) + rt_mode, key = dispatcher.dispatch( + num_tokens=8, uniform_decode=False, has_lora=False ) - rt_mode, key = dispatcher.dispatch(desc_full_exact) if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL assert key == desc_full_exact @@ -122,32 +136,37 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): assert rt_mode == CUDAGraphMode.NONE # 2. uniform decode batch, size in cudagraph size list - desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True) - rt_mode, key = dispatcher.dispatch(desc_uniform_exact) + desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True) + rt_mode, key = dispatcher.dispatch( + num_tokens=8, uniform_decode=True, has_lora=False + ) if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL - assert key == desc_uniform_exact.non_uniform + assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs() elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]: assert rt_mode == CUDAGraphMode.FULL assert key == desc_uniform_exact elif cudagraph_mode_str == "PIECEWISE": assert rt_mode == CUDAGraphMode.PIECEWISE - assert key == desc_uniform_exact.non_uniform + assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs() else: assert rt_mode == CUDAGraphMode.NONE # 3. No key match - desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False) - rt_mode, key = dispatcher.dispatch(desc_no_match) + rt_mode, key = dispatcher.dispatch( + num_tokens=15, uniform_decode=False, has_lora=False + ) assert rt_mode == CUDAGraphMode.NONE - assert key is None + assert key == BatchDescriptor(num_tokens=15) # 4. Cascade attention should have a fall back mode - desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) - rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True) + desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) + rt_mode, key = dispatcher.dispatch( + num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True + ) if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE - assert key == desc_full_exact.non_uniform + assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs() else: assert rt_mode == CUDAGraphMode.NONE diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7cb490e391ab..635419bc7cad 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple): """ num_tokens: int - uniform_decode: bool = False + num_reqs: int | None = None """ - False can also be used for an uniform decode batch to dispatch to the - cudagraph supporting non-uniform batches. + Number of requests in the batch. Can be None for PIECEWISE cudagraphs where + the cudagraphs can handle any number of requests. + """ + uniform: bool = False + """ + True if all the requests in the batch have the same number of tokens. """ has_lora: bool = False """ Whether this batch has active LoRA adapters. """ - @property - def non_uniform(self) -> "BatchDescriptor": + def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": """ - Return a non-uniform version of current batch descriptor. + Return a relaxed version of current batch descriptor that is still compatible + with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs). 
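+
+        Illustrative example (assumed values): BatchDescriptor(num_tokens=32,
+        num_reqs=8, uniform=True, has_lora=False) relaxes to
+        BatchDescriptor(num_tokens=32, num_reqs=None, uniform=False, has_lora=False).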
""" return BatchDescriptor( - self.num_tokens, uniform_decode=False, has_lora=self.has_lora + self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora ) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8159f4096107..dbd72b298b1f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -930,31 +930,12 @@ def build( if num_decodes > 0: pure_decode = num_prefills == 0 - # possible required padding for cudagraph replay use_cudagraph = ( self.enable_cuda_graph and pure_decode and num_decode_tokens <= self._decode_cudagraph_max_bs ) - if use_cudagraph: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_decode_tokens - ) - # Carefully fulfill the padding region with reasonable value - # on cpu. - # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[ - 1 + num_decodes : 1 + num_input_tokens - ].fill_(paged_kv_indptr_cpu[-1]) - # Fill the remaining paged_kv_last_page_len_cpu with 1. - # This is because flashinfer treats 0 as a full page - # instead of empty. - self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( - 1 - ) - - else: - num_input_tokens = num_decode_tokens + num_input_tokens = num_decode_tokens attn_metadata.decode_wrapper = self._get_decode_wrapper( num_input_tokens, use_cudagraph diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 0d875565fc99..a9705db59f19 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -107,6 +107,8 @@ def _compute_prefix_caching_block_indices( ) # -1 in case it's non-computed and causes later issues with indexing block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) + # -1 in the case we have a padded request (0 seq-len) + block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0) return ( block_idx_last_computed_token, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index cebfe8a3ff04..18e91fd4fd6a 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -72,6 +72,7 @@ class CommonAttentionMetadata: num_reqs: int """Number of requests""" + # TODO(lucas): rename to num_tokens since it may be padded and this is misleading num_actual_tokens: int """Total number of tokens in batch""" max_query_len: int @@ -857,7 +858,9 @@ def split_decodes_and_prefills( if require_uniform: is_prefill = query_lens != query_lens[0] else: - is_prefill = query_lens > decode_threshold + # 0-query len indicates a padded request; leave this at the back + # of the batch with the prefills + is_prefill = (query_lens > decode_threshold) | (query_lens == 0) if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index b480ac78f23c..ef0f8d9e6745 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -4,6 +4,9 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor +from vllm.logger import init_logger + +logger = init_logger(__name__) class CudagraphDispatcher: @@ -28,7 +31,11 @@ class CudagraphDispatcher: def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config - self.cudagraph_mode = self.compilation_config.cudagraph_mode + self.uniform_decode_query_len = ( + 1 + if not 
self.vllm_config.speculative_config + else 1 + self.vllm_config.speculative_config.num_speculative_tokens + ) # Dict to store valid cudagraph dispatching keys. self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = { @@ -36,25 +43,42 @@ def __init__(self, vllm_config: VllmConfig): CUDAGraphMode.FULL: set(), } - not_use_piecewise_compilation = ( - not self.cudagraph_mode.requires_piecewise_compilation() - ) - assert ( - not_use_piecewise_compilation + not self.compilation_config.cudagraph_mode.requires_piecewise_compilation() or self.compilation_config.is_attention_compiled_piecewise() ), ( "Compilation mode should be CompilationMode.VLLM_COMPILE when " "cudagraph_mode piecewise cudagraphs is used, " "and attention should be in splitting_ops or " "inductor splitting should be used. " - f"cudagraph_mode={self.cudagraph_mode}, " + f"cudagraph_mode={self.compilation_config.cudagraph_mode}, " f"compilation_mode={self.compilation_config.mode}, " f"splitting_ops={self.compilation_config.splitting_ops}" ) self.keys_initialized = False + def _create_padded_batch_descriptor( + self, num_tokens: int, uniform_decode: bool, has_lora: bool + ) -> BatchDescriptor: + max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs + uniform_decode_query_len = self.uniform_decode_query_len + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens) + + if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): + num_reqs = num_tokens_padded // uniform_decode_query_len + assert num_tokens_padded % uniform_decode_query_len == 0 + else: + uniform_decode = False + num_reqs = min(num_tokens_padded, max_num_seqs) + + return BatchDescriptor( + num_tokens=num_tokens_padded, + num_reqs=num_reqs, + uniform=uniform_decode, + has_lora=has_lora, + ) + def add_cudagraph_key( self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor ): @@ -66,7 +90,9 @@ def add_cudagraph_key( def initialize_cudagraph_keys( self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int ): - # This should be called only after attention backend is initialized. + # This should be called only after attention backend is initialized. So we can + # get the correct cudagraph mode after backend support is resolved. + self.cudagraph_mode = cudagraph_mode # LoRA activation cases to specialize the cuda graphs on if self.vllm_config.lora_config: @@ -86,9 +112,9 @@ def initialize_cudagraph_keys( ): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), - BatchDescriptor( - num_tokens=bs, uniform_decode=False, has_lora=has_lora - ), + self._create_padded_batch_descriptor( + bs, False, has_lora + ).relax_for_mixed_batch_cudagraphs(), ) # if decode cudagraph mode is FULL, and we don't already have mixed @@ -109,40 +135,49 @@ def initialize_cudagraph_keys( for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): self.add_cudagraph_key( CUDAGraphMode.FULL, - BatchDescriptor( - num_tokens=bs, uniform_decode=True, has_lora=has_lora - ), + self._create_padded_batch_descriptor(bs, True, has_lora), ) + self.keys_initialized = True def dispatch( - self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False - ) -> tuple[CUDAGraphMode, BatchDescriptor | None]: + self, + num_tokens: int, + uniform_decode: bool, + has_lora: bool, + use_cascade_attn: bool = False, + ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), dispatch to a cudagraph runtime mode and the valid batch descriptor. 
A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). """ - # if not initialized, just skip dispatching. - if not self.keys_initialized: - return CUDAGraphMode.NONE, None + if ( + not self.keys_initialized + or self.cudagraph_mode == CUDAGraphMode.NONE + or num_tokens > self.compilation_config.max_cudagraph_capture_size + ): + return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) + + batch_desc = self._create_padded_batch_descriptor( + num_tokens, uniform_decode, has_lora + ) + relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() - non_uniform_key = batch_descriptor.non_uniform - # if a batch use cascade attention, bypass checking full cudagraphs if not use_cascade_attn: # check if key exists for full cudagraph - if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, batch_descriptor + if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, batch_desc - # otherwise, check if non-uniform key exists - if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, non_uniform_key + # otherwise, check if the relaxed key exists + if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, relaxed_batch_desc - # also check if non-uniform key exists for more "general" + # also check if the relaxed key exists for more "general" # piecewise cudagraph - if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: - return CUDAGraphMode.PIECEWISE, non_uniform_key + if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: + return CUDAGraphMode.PIECEWISE, relaxed_batch_desc - # finally, just return no cudagraphs - return CUDAGraphMode.NONE, None + # finally, just return no cudagraphs and a trivial batch descriptor + return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 464fbf11a21a..064f2f0360cb 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -9,6 +9,7 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.logger import init_logger from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, UBatchSlices, check_ubatch_thresholds, create_ubatch_slices, @@ -88,6 +89,17 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch return num_tokens_across_dp.cpu() +# This just pads the second ubatch slice out to the total number of tokens +# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding. 
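+# For example (illustrative sizes): with ubatch token slices [0:6) and [6:10) and a
+# DP-padded total of 16 tokens, the second slice becomes [6:16).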
+def _pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int): + padded_second_ubatch_slice = slice( + ubatch_slices[1].token_slice.start, num_total_tokens + ) + ubatch_slices[1] = UBatchSlice( + padded_second_ubatch_slice, padded_second_ubatch_slice + ) + + def _synchronize_dp_ranks( num_tokens_unpadded: int, num_tokens_padded: int, @@ -220,11 +232,14 @@ def coordinate_batch_across_dp( # to the second ubatch in pad_out_ubatch_slice after attention # metadata creation assert num_tokens_after_padding is not None - token_split_point = int(num_tokens_after_padding[0].item()) // 2 + num_tokens_padded = int(num_tokens_after_padding[0].item()) + token_split_point = int(num_tokens_padded) // 2 assert num_scheduled_tokens_per_request is not None ubatch_slices = create_ubatch_slices( num_scheduled_tokens_per_request, token_split_point ) + ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded) + assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d3c61794f8b0..4f167fd16a1d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -150,7 +150,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.ubatch_utils import ( - UBatchSlice, UBatchSlices, check_ubatch_thresholds, ) @@ -1238,17 +1237,13 @@ def _prepare_inputs( self, scheduler_output: "SchedulerOutput", num_scheduled_tokens: np.ndarray, - max_num_scheduled_tokens: int, ) -> tuple[ torch.Tensor, SpecDecodeMetadata | None, - UBatchSlices | None, - torch.Tensor | None, ]: """ :return: tuple[ logits_indices, spec_decode_metadata, - ubatch_slices, num_tokens_across_dp, ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -1363,28 +1358,6 @@ def _prepare_inputs( self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded) - uniform_decode = ( - max_num_scheduled_tokens == self.uniform_decode_query_len - ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - - # Disable DP padding when running eager to avoid excessive padding when - # running prefills. This lets us set enforce_eager on the prefiller in - # a P/D setup and still use CUDA graphs (enabled by this padding) on the - # decoder. 
- allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - - ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( - num_tokens_unpadded=num_tokens_unpadded, - parallel_config=self.parallel_config, - allow_microbatching=True, - allow_dp_padding=allow_dp_padding, - num_tokens_padded=num_tokens_padded, - uniform_decode=uniform_decode, - num_scheduled_tokens_per_request=num_scheduled_tokens, - ) - self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens ) @@ -1485,15 +1458,15 @@ def _prepare_inputs( return ( logits_indices, spec_decode_metadata, - ubatch_slices, - num_tokens_across_dp, ) def _build_attention_metadata( self, - total_num_scheduled_tokens: int, - max_num_scheduled_tokens: int, + num_tokens: int, num_reqs: int, + max_query_len: int, + num_tokens_padded: int | None = None, + num_reqs_padded: int | None = None, ubatch_slices: UBatchSlices | None = None, logits_indices: torch.Tensor | None = None, use_spec_decode: bool = False, @@ -1504,6 +1477,9 @@ def _build_attention_metadata( """ :return: tuple[attn_metadata, spec_decode_common_attn_metadata] """ + num_tokens_padded = num_tokens_padded or num_tokens + num_reqs_padded = num_reqs_padded or num_reqs + logits_indices_padded = None num_logits_indices = None if logits_indices is not None: @@ -1521,28 +1497,13 @@ def _build_attention_metadata( self.dcp_rank, self.parallel_config.cp_kv_cache_interleave_size, ) - self.dcp_local_seq_lens.copy_to_gpu(num_reqs) + self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0) + self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded) attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: attn_metadata = [dict() for _ in range(len(ubatch_slices))] - # Used in the below loop - query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] - query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] - seq_lens = self.seq_lens.gpu[:num_reqs] - seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs - ] - - dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None - if self.dcp_world_size > 1: - dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs] - dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs] - - spec_decode_common_attn_metadata = None - if for_cudagraph_capture: # For some attention backends (e.g. FA) with sliding window models we need # to make sure the backend see a max_seq_len that is larger to the sliding @@ -1558,6 +1519,22 @@ def _build_attention_metadata( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() + # Used in the below loop, uses padded shapes + query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1] + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1] + seq_lens = self.seq_lens.gpu[:num_reqs_padded] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded] + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ] + + dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None + if self.dcp_world_size > 1: + dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded] + dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded] + + spec_decode_common_attn_metadata = None + # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
for kv_cache_gid, kv_cache_group in enumerate( @@ -1566,30 +1543,31 @@ def _build_attention_metadata( encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens( num_scheduled_tokens or {}, kv_cache_group.kv_cache_spec, - num_reqs, + num_reqs_padded, ) if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. blk_table_tensor = torch.zeros( - (num_reqs, 1), + (num_tokens_padded, 1), dtype=torch.int32, device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens,), + (num_tokens_padded,), dtype=torch.int64, device=self.device, ) else: blk_table = self.input_batch.block_table[kv_cache_gid] - blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] + blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) + slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[num_tokens:num_tokens_padded].fill_(-1) + blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -1597,9 +1575,9 @@ def _build_attention_metadata( seq_lens=seq_lens, seq_lens_cpu=seq_lens_cpu, num_computed_tokens_cpu=num_computed_tokens_cpu, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, + num_actual_tokens=num_tokens_padded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, @@ -1630,9 +1608,11 @@ def _build_attention_metadata( extra_attn_metadata_args = {} if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_accepted_tokens=self.num_accepted_tokens.gpu[ + :num_reqs_padded + ], num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ - :num_reqs + :num_reqs_padded ], ) @@ -1676,6 +1656,7 @@ def _build_attention_metadata( def _compute_cascade_attn_prefix_lens( self, num_scheduled_tokens: np.ndarray, + num_computed_tokens: np.ndarray, num_common_prefix_blocks: list[int], ) -> list[list[int]] | None: """ @@ -1698,6 +1679,7 @@ def _compute_cascade_attn_prefix_lens( # 0 if cascade attention should not be used cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, + num_computed_tokens, num_common_prefix_blocks[kv_cache_gid], attn_group.kv_cache_spec, attn_group.get_metadata_builder(), @@ -1710,6 +1692,7 @@ def _compute_cascade_attn_prefix_lens( def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, + num_computed_tokens: np.ndarray, num_common_prefix_blocks: int, kv_cache_spec: KVCacheSpec, attn_metadata_builder: AttentionMetadataBuilder, @@ -1776,10 +1759,7 @@ def _compute_cascade_attn_prefix_len( # and the second kernel will get an empty input. While this is not # a fundamental problem, our current implementation does not support # this case. 
- num_reqs = len(num_scheduled_tokens) - common_prefix_len = min( - common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min() - ) + common_prefix_len = min(common_prefix_len, num_computed_tokens.min()) # common_prefix_len should be a multiple of the block size. common_prefix_len = ( common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size @@ -2333,19 +2313,6 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: log_stats=self.parallel_config.eplb_config.log_balancedness, ) - # This is where the second ubatch is adjusted to account for the padding. - # Should be called after attention metadata creation. This just pads - # the second ubatch slice out to the total number of tokens - # (num_tokens + padding) - @staticmethod - def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int): - padded_second_ubatch_slice = slice( - ubatch_slices[1].token_slice.start, num_total_tokens - ) - ubatch_slices[1] = UBatchSlice( - padded_second_ubatch_slice, padded_second_ubatch_slice - ) - def _pool( self, hidden_states: torch.Tensor, @@ -2390,18 +2357,7 @@ def _pool( pooler_output=pooler_output, ) - def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] - ): - # Use CUDA graphs. - # Add padding to the batch size. - return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) - - # Eager mode. + def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size @@ -2737,6 +2693,87 @@ def _model_forward( **model_kwargs, ) + def _determine_batch_execution_and_padding( + self, + num_tokens: int, + num_reqs: int, + num_scheduled_tokens_np: np.ndarray, + max_num_scheduled_tokens: int, + use_cascade_attn: bool, + allow_microbatching: bool = True, + force_eager: bool = False, + # For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will + # be improved in model runner v2) + force_uniform_decode: bool | None = None, + force_has_lora: bool | None = None, + ) -> tuple[ + CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None + ]: + num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) + uniform_decode = ( + ( + (max_num_scheduled_tokens == self.uniform_decode_query_len) + and (num_tokens_padded == max_num_scheduled_tokens * num_reqs) + ) + if force_uniform_decode is None + else force_uniform_decode + ) + + has_lora = ( + len(self.input_batch.lora_id_to_lora_request) > 0 + if force_has_lora is None + else force_has_lora + ) + + dispatch_cudagraph = ( + lambda num_tokens: self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, + has_lora=has_lora, + use_cascade_attn=use_cascade_attn, + uniform_decode=uniform_decode, + ) + if not force_eager + else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) + ) + + cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded) + num_tokens_padded = batch_descriptor.num_tokens + + # Extra coordination when running data-parallel since we need to coordinate + # across ranks + ubatch_slices, num_tokens_across_dp = None, None + if self.vllm_config.parallel_config.data_parallel_size > 1: + # Disable DP padding when running eager to avoid excessive padding when 
+ # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller + # in a P/D setup and still use CUDA graphs (enabled by this padding) on the + # decoder. + allow_dp_padding = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) + + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens_padded, + parallel_config=self.parallel_config, + allow_microbatching=allow_microbatching, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens_np, + ) + + # Extract DP padding if there is any + if num_tokens_across_dp is not None: + dp_rank = self.parallel_config.data_parallel_rank + num_tokens_padded = int(num_tokens_across_dp[dp_rank].item()) + + # Re-dispatch with DP padding + cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded) + # Assert to make sure the agreed upon token count is correct otherwise + # num_tokens_across_dp will no-longer be valid + assert batch_descriptor.num_tokens == num_tokens_padded + + return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp + @torch.inference_mode() def execute_model( self, @@ -2789,7 +2826,7 @@ def execute_model( # returns True. before returning early here we call # dummy run to ensure coordinate_batch_across_dp # is called into to avoid out of sync issues. - self._dummy_run(self._get_num_input_tokens(1)) + self._dummy_run(1) if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT @@ -2808,36 +2845,63 @@ def execute_model( tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens ( logits_indices, spec_decode_metadata, - ubatch_slices, - num_tokens_across_dp, ) = self._prepare_inputs( - scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens + scheduler_output, + num_scheduled_tokens_np, ) cascade_attn_prefix_lens = None # Disable cascade attention when using microbatching (DBO) - if self.cascade_attn_enabled and ubatch_slices is None: + if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: # Pre-compute cascade attention prefix lengths - # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], scheduler_output.num_common_prefix_blocks, ) - # TODO(lucas): move cudagraph dispatching here: - # https://github.com/vllm-project/vllm/issues/23789 + ( + cudagraph_mode, + batch_desc, + ubatch_slices, + num_tokens_across_dp, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + ) + + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "ubatch_slices: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + ubatch_slices, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens 
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - attn_metadata, spec_decode_common_attn_metadata = ( + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + (attn_metadata, spec_decode_common_attn_metadata) = ( self._build_attention_metadata( - total_num_scheduled_tokens=total_num_scheduled_tokens, - max_num_scheduled_tokens=max_num_scheduled_tokens, + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, ubatch_slices=ubatch_slices, logits_indices=logits_indices, use_spec_decode=use_spec_decode, @@ -2846,49 +2910,22 @@ def execute_model( ) ) - dp_rank = self.parallel_config.data_parallel_rank - if ubatch_slices: - assert num_tokens_across_dp is not None - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif num_tokens_across_dp is not None: - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - else: - num_input_tokens = self._get_num_input_tokens( - scheduler_output.total_num_scheduled_tokens - ) - - ( - input_ids, - inputs_embeds, - positions, - intermediate_tensors, - model_kwargs, - ec_connector_output, - ) = self._preprocess( - scheduler_output, num_input_tokens, intermediate_tensors - ) - - uniform_decode = ( - max_num_scheduled_tokens == self.uniform_decode_query_len - ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - batch_desc = BatchDescriptor( - num_tokens=num_input_tokens, - uniform_decode=uniform_decode, - has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, - ) - cudagraph_runtime_mode, batch_descriptor = ( - self.cudagraph_dispatcher.dispatch( - batch_desc, - use_cascade_attn=cascade_attn_prefix_lens is not None, - ) + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess( + scheduler_output, num_tokens_padded, intermediate_tensors ) # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible # with CUDA graph capture. 
if self.calculate_kv_scales: - cudagraph_runtime_mode = CUDAGraphMode.NONE + cudagraph_mode = CUDAGraphMode.NONE # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False @@ -2898,10 +2935,10 @@ def execute_model( set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_input_tokens, + num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, ubatch_slices=ubatch_slices, ), record_function_or_nullcontext("gpu_model_runner: forward"), @@ -2951,7 +2988,7 @@ def execute_model( if not get_pp_group().is_last_rank: all_gather_tensors = { "residual": not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens + self.vllm_config, num_tokens_padded ) } get_pp_group().send_tensor_dict( @@ -3840,52 +3877,44 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) + num_tokens_unpadded = int(num_scheduled_tokens.sum()) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - # Disable DP padding when running eager - allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - - # We currently only microbatch if the number of tokens is - # over a certain threshold. - ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( - num_tokens_unpadded=total_num_scheduled_tokens, - parallel_config=self.vllm_config.parallel_config, - allow_microbatching=allow_microbatching, - allow_dp_padding=allow_dp_padding, - num_tokens_padded=total_num_scheduled_tokens, - uniform_decode=uniform_decode, - num_scheduled_tokens_per_request=num_scheduled_tokens, - ) - num_tokens_after_padding = num_tokens - if num_tokens_across_dp is not None: - dp_rank = self.parallel_config.data_parallel_rank - num_tokens_after_padding = int(num_tokens_across_dp[dp_rank]) - - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = ( - self.cudagraph_dispatcher.dispatch( - BatchDescriptor( - num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode, - has_lora=activate_lora and self.lora_config is not None, - ) + _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile + or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, ) - if not is_profile - else (CUDAGraphMode.NONE, None) ) - if cudagraph_runtime_mode is not None: - # we allow forcing NONE when the dispatcher disagrees to support - # warm ups for cudagraph capture - assert ( - 
cudagraph_runtime_mode == CUDAGraphMode.NONE - or cudagraph_runtime_mode == _cg_mode - ), ( - f"Cudagraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." - ) + + if cudagraph_runtime_mode is None: + cudagraph_runtime_mode = _cudagraph_mode else: - cudagraph_runtime_mode = _cg_mode + assert cudagraph_runtime_mode == _cudagraph_mode, ( + f"Cudagraph runtime mode mismatch in dummy_run. " + f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) attn_metadata: PerLayerAttnMetadata | None = None @@ -3908,9 +3937,9 @@ def _dummy_run( self.query_start_loc.copy_to_gpu() attn_metadata, _ = self._build_attention_metadata( - total_num_scheduled_tokens=num_tokens, - max_num_scheduled_tokens=max_query_len, - num_reqs=num_reqs, + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, ubatch_slices=ubatch_slices, for_cudagraph_capture=True, ) @@ -3923,29 +3952,29 @@ def _dummy_run( remove_lora, ): # Make sure padding doesn't exceed max_num_tokens - assert num_tokens_after_padding <= self.max_num_tokens - model_kwargs = self._init_model_kwargs(num_tokens_after_padding) + assert num_tokens_padded <= self.max_num_tokens + model_kwargs = self._init_model_kwargs(num_tokens_padded) if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } elif self.enable_prompt_embeds: input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding] - model_kwargs = self._init_model_kwargs(num_tokens_after_padding) + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + model_kwargs = self._init_model_kwargs(num_tokens_padded) else: - input_ids = self.input_ids.gpu[:num_tokens_after_padding] + input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None if self.uses_mrope: - positions = self.mrope_positions.gpu[:, :num_tokens_after_padding] + positions = self.mrope_positions.gpu[:, :num_tokens_padded] elif self.uses_xdrope_dim > 0: - positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding] + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: - positions = self.positions.gpu[:num_tokens_after_padding] + positions = self.positions.gpu[:num_tokens_padded] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -3960,26 +3989,26 @@ def _dummy_run( ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens_after_padding, None, False + num_tokens_padded, None, False ) if ubatch_slices is not None: # Adjust values to reflect a single ubatch. # TODO(sage,lucas): this is cruft that should be addressed in # the padding refactor. 
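+            # Illustrative example: two ubatches of 8 tokens each enter
+            # set_forward_context with num_tokens=8 rather than 16.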
- num_tokens_after_padding = ubatch_slices[0].num_tokens + num_tokens_padded = ubatch_slices[0].num_tokens if num_tokens_across_dp is not None: - num_tokens_across_dp[:] = num_tokens_after_padding + num_tokens_across_dp[:] = num_tokens_padded with ( self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_tokens_after_padding, + num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + batch_descriptor=batch_desc, ubatch_slices=ubatch_slices, ), ): @@ -4705,8 +4734,7 @@ def _check_and_update_cudagraph_mode( # Trigger cudagraph dispatching keys initialization after # resolved cudagraph mode. - cudagraph_mode = self.compilation_config.cudagraph_mode - assert cudagraph_mode is not None + self.compilation_config.cudagraph_mode = cudagraph_mode self.cudagraph_dispatcher.initialize_cudagraph_keys( cudagraph_mode, self.uniform_decode_query_len ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6a4bfde5f972..d0c6091ce2a6 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -8,12 +8,13 @@ from types import NoneType from typing import TYPE_CHECKING, Any, cast +import numpy as np import torch import torch.distributed import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -487,6 +488,7 @@ def compile_or_warm_up_model(self) -> None: hidden_states, last_hidden_states = self.model_runner._dummy_run( num_tokens=max_num_reqs, skip_eplb=True, + cudagraph_runtime_mode=CUDAGraphMode.NONE, ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) @@ -534,12 +536,39 @@ def execute_model( intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens) - all_gather_tensors = { - "residual": not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens + all_gather_tensors = {} + compilation_config = self.vllm_config.compilation_config + parallel_config = self.vllm_config.parallel_config + + if ( + parallel_config.pipeline_parallel_size > 1 + and compilation_config.pass_config.enable_sequence_parallelism + and forward_pass + ): + # currently only supported by V1 GPUModelRunner + assert isinstance(self.model_runner, GPUModelRunner) + num_scheduled_tokens_np = np.array( + list(scheduler_output.num_scheduled_tokens.values()), + dtype=np.int32, ) - } + # TODO(lucas): This is pretty gross; ideally we should only ever call + # `_determine_batch_execution_and_padding` once (will get called again + # in `execute_model`) but this requires a larger refactor of PP. 
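+            # Only batch_desc.num_tokens (the padded token count) is used here, to
+            # decide whether the residual tensor is scattered for sequence parallelism.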
+ _, batch_desc, _, _ = ( + self.model_runner._determine_batch_execution_and_padding( + num_tokens=num_scheduled_tokens, + num_reqs=len(num_scheduled_tokens_np), + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=num_scheduled_tokens_np.max(), + use_cascade_attn=False, # TODO(lucas): Handle cascade attention + ) + ) + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, batch_desc.num_tokens + ) + } + if forward_pass and not get_pp_group().is_first_rank: tensor_dict = get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group(),