diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py
index 1cbd0fe56be6..f60861e3489d 100644
--- a/tests/v1/attention/test_attention_splitting.py
+++ b/tests/v1/attention/test_attention_splitting.py
@@ -13,7 +13,7 @@
     split_attn_metadata,
     split_decodes_and_prefills,
 )
-from vllm.v1.worker.ubatch_utils import create_ubatch_slices
+from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices


 @pytest.fixture
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
     qsl_np = common.query_start_loc_cpu.numpy()
     num_tokens = common.num_actual_tokens

-    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
-    assert len(ubatch_slices) == 2
+    ubatch_slices, _ = maybe_create_ubatch_slices(
+        True,
+        num_scheduled_tokens,
+        num_tokens,
+        batch_spec.batch_size,
+        split_point=split_point,
+    )
+    assert ubatch_slices is not None and len(ubatch_slices) == 2

     first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
     second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 1c7845a14b74..31428db2d3af 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1258,7 +1258,7 @@ def _pad_batch_across_dp(
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
@@ -1267,7 +1267,7 @@ def _pad_batch_across_dp(
             uniform_decode=None,
             num_scheduled_tokens_per_request=None,
         )
-        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

         num_tokens_dp_padded = num_tokens_padded
         if num_toks_across_dp is not None:
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
index 6539d72d81cb..5da55d740c34 100644
--- a/vllm/v1/worker/dp_utils.py
+++ b/vllm/v1/worker/dp_utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -9,10 +10,7 @@
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
-    UBatchSlice,
-    UBatchSlices,
     check_ubatch_thresholds,
-    create_ubatch_slices,
     is_second_ubatch_empty,
 )

@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()


-# This just pads the second ubatch slice out to the total number of tokens
-# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
-def _pad_out_ubatch_slice(
-    ubatch_slices: UBatchSlices, num_total_tokens: int
-) -> UBatchSlices:
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    ubatch_slices[1] = UBatchSlice(
-        ubatch_slices[1].request_slice, padded_second_token_slice
-    )
-    return ubatch_slices
-
-
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[UBatchSlices | None, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -204,7 +188,7 @@ def coordinate_batch_across_dp(
     """
     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return None, None
+        return False, None

     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -228,23 +212,4 @@ def coordinate_batch_across_dp(
         parallel_config,
     )

-    # Don't microbatch unless every other DP worker is also microbatching
-    if not should_ubatch:
-        return (None, num_tokens_after_padding)
-
-    # This doesn't actually pad the ubatch slices. It just initializes the
-    # split point to the padded value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_slice after attention
-    # metadata creation
-    assert num_tokens_after_padding is not None
-    num_tokens_padded = int(num_tokens_after_padding[0].item())
-    token_split_point = int(num_tokens_padded) // 2
-
-    assert num_scheduled_tokens_per_request is not None
-    ubatch_slices = create_ubatch_slices(
-        num_scheduled_tokens_per_request, token_split_point
-    )
-    ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
-    assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
-
-    return (ubatch_slices, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3f043e3b2648..7fc4d4d5c348 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -153,6 +153,7 @@
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
+    maybe_create_ubatch_slices,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp

@@ -2758,7 +2759,7 @@ def _determine_batch_execution_and_padding(
     ) -> tuple[
         CUDAGraphMode,
         BatchDescriptor,
-        UBatchSlices | None,
+        bool,
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
@@ -2794,7 +2795,7 @@ def _determine_batch_execution_and_padding(

         # Extra coordination when running data-parallel since we need to coordinate
         # across ranks
-        ubatch_slices, num_tokens_across_dp = None, None
+        should_ubatch, num_tokens_across_dp = False, None
         if self.vllm_config.parallel_config.data_parallel_size > 1:
             # Disable DP padding when running eager to avoid excessive padding when
             # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2804,8 +2805,8 @@ def _determine_batch_execution_and_padding(
                 self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             )

-            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-                num_tokens_unpadded=num_tokens_padded,
+            should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
                 parallel_config=self.parallel_config,
                 allow_microbatching=allow_microbatching,
                 allow_dp_padding=allow_dp_padding,
@@ -2837,7 +2838,7 @@ def _determine_batch_execution_and_padding(
         return (
             cudagraph_mode,
             batch_descriptor,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         )
@@ -2936,7 +2937,7 @@ def execute_model(
         (
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         ) = self._determine_batch_execution_and_padding(
@@ -2949,10 +2950,10 @@ def execute_model(

         logger.debug(
             "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
-            "ubatch_slices: %s, num_tokens_across_dp: %s",
+            "should_ubatch: %s, num_tokens_across_dp: %s",
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
         )

@@ -2960,10 +2961,18 @@ def execute_model(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch,
+            num_scheduled_tokens_np,
+            num_tokens_padded,
+            num_reqs_padded,
+        )

-        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
         pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
+
         (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
@@ -2971,7 +2980,7 @@ def execute_model(
                 num_reqs=num_reqs,
                 num_reqs_padded=num_reqs_padded if pad_attn else None,
                 max_query_len=max_num_scheduled_tokens,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_attn,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
                 num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
@@ -3008,7 +3017,7 @@ def execute_model(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -3961,7 +3970,7 @@ def _dummy_run(

         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

-        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
+        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
             self._determine_batch_execution_and_padding(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs,
@@ -3995,6 +4004,9 @@ def _dummy_run(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+        )

         attn_metadata: PerLayerAttnMetadata | None = None

@@ -4016,11 +4028,12 @@ def _dummy_run(
             self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
             self.query_start_loc.copy_to_gpu()

+            pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
             attn_metadata, _ = self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs_padded,
                 max_query_len=max_query_len,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
                 for_cudagraph_capture=is_graph_capturing,
             )

@@ -4072,11 +4085,11 @@ def _dummy_run(
                 num_tokens_padded, None, False
             )

-            if ubatch_slices is not None:
+            if ubatch_slices_padded is not None:
                 # Adjust values to reflect a single ubatch.
                 # TODO(sage,lucas): this is cruft that should be addressed in
                 # the padding refactor.
-                num_tokens_padded = ubatch_slices[0].num_tokens
+                num_tokens_padded = ubatch_slices_padded[0].num_tokens

                 if num_tokens_across_dp is not None:
                     num_tokens_across_dp[:] = num_tokens_padded
@@ -4089,7 +4102,7 @@ def _dummy_run(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
         ):
             outputs = self.model(
diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py
index 33a1921d2d98..44788476fc9c 100644
--- a/vllm/v1/worker/ubatch_utils.py
+++ b/vllm/v1/worker/ubatch_utils.py
@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold


-def create_ubatch_slices(
-    num_scheduled_tokens: np.ndarray, split_point: int
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slices(
+    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
+    # TODO(lucas): handle empty second ubatch
+    padded_second_request_slice = slice(
+        ubatch_slices[1].request_slice.start, num_reqs_padded
+    )
+    padded_second_token_slice = slice(
+        ubatch_slices[1].token_slice.start, num_total_tokens
+    )
+    return [
+        ubatch_slices[0],
+        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    ]
+
+
+def maybe_create_ubatch_slices(
+    should_ubatch: bool,
+    num_scheduled_tokens: np.ndarray,
+    num_tokens_padded: int,
+    num_reqs_padded: int,
+    split_point: int | None = None,
+) -> tuple[UBatchSlices | None, UBatchSlices | None]:
+    if not should_ubatch:
+        return None, None
+
+    if split_point is None:
+        split_point = int(num_tokens_padded) // 2
+
     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
@@ -67,7 +95,15 @@ def create_ubatch_slices(
     )
     second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)

-    return [
+    ubatch_slices = [
         UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
         UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
     ]
+
+    ubatch_slices_padded = _pad_out_ubatch_slices(
+        ubatch_slices, num_tokens_padded, num_reqs_padded
+    )
+
+    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
+
+    return ubatch_slices, ubatch_slices_padded
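
For reviewers: a minimal sketch of how the relocated helper is intended to be called after this change (assuming a vLLM checkout with this patch applied; the two-request toy batch and the padded sizes below are made up for illustration). `coordinate_batch_across_dp` now only answers *whether* to microbatch, and `maybe_create_ubatch_slices` builds both the unpadded slices and the DP-padded variant that `_pad_out_ubatch_slice` used to produce inside `dp_utils.py`:

```python
# Illustrative sketch only: assumes this patch is applied to a vLLM checkout.
import numpy as np

from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices

# Hypothetical batch: two requests with 8 scheduled tokens each (16 real tokens),
# DP-padded out to 20 tokens and 2 requests.
num_scheduled_tokens = np.array([8, 8], dtype=np.int32)

unpadded, padded = maybe_create_ubatch_slices(
    should_ubatch=True,
    num_scheduled_tokens=num_scheduled_tokens,
    num_tokens_padded=20,
    num_reqs_padded=2,
)

assert unpadded is not None and padded is not None
# The unpadded slices cover only the real tokens; the padded variant stretches
# the second ubatch out to the DP-padded token count, matching the assert
# inside maybe_create_ubatch_slices.
print([s.token_slice for s in unpadded])
print([s.token_slice for s in padded])
assert sum(s.num_tokens for s in padded) == 20

# When the DP ranks decide not to microbatch, both values are None and the
# model runner falls back to a single (possibly padded) batch.
assert maybe_create_ubatch_slices(False, num_scheduled_tokens, 20, 2) == (None, None)
```

The `execute_model` path then picks between the two: the padded slices feed `set_forward_context` (and the attention metadata when `cudagraph_mode == CUDAGraphMode.FULL`), while the unpadded slices are used for attention metadata otherwise.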