12 changes: 9 additions & 3 deletions tests/v1/attention/test_attention_splitting.py
@@ -13,7 +13,7 @@
split_attn_metadata,
split_decodes_and_prefills,
)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices


@pytest.fixture
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
qsl_np = common.query_start_loc_cpu.numpy()
num_tokens = common.num_actual_tokens

ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
assert len(ubatch_slices) == 2
ubatch_slices, _ = maybe_create_ubatch_slices(
True,
num_scheduled_tokens,
num_tokens,
batch_spec.batch_size,
split_point=split_point,
)
assert ubatch_slices is not None and len(ubatch_slices) == 2

first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
4 changes: 2 additions & 2 deletions vllm/v1/spec_decode/eagle.py
@@ -1258,7 +1258,7 @@ def _pad_batch_across_dp(
num_tokens_padded: int,
) -> tuple[int, torch.Tensor]:
# TODO(Flechman): support DBO ubatching
ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
@@ -1267,7 +1267,7 @@
uniform_decode=None,
num_scheduled_tokens_per_request=None,
)
assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

num_tokens_dp_padded = num_tokens_padded
if num_toks_across_dp is not None:
43 changes: 4 additions & 39 deletions vllm/v1/worker/dp_utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import numpy as np
import torch
import torch.distributed as dist
@@ -9,10 +10,7 @@
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
is_second_ubatch_empty,
)

@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu()


# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slice(
ubatch_slices: UBatchSlices, num_total_tokens: int
) -> UBatchSlices:
padded_second_token_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
ubatch_slices[1] = UBatchSlice(
ubatch_slices[1].request_slice, padded_second_token_slice
)
return ubatch_slices


def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
num_scheduled_tokens_per_request: np.ndarray | None = None,
) -> tuple[UBatchSlices | None, torch.Tensor | None]:
) -> tuple[bool, torch.Tensor | None]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
@@ -204,7 +188,7 @@
"""
if parallel_config.data_parallel_size == 1:
# Early exit.
return None, None
return False, None

# If the caller has explicitly enabled microbatching.
should_attempt_ubatching = False
@@ -228,23 +212,4 @@
parallel_config,
)

# Don't microbatch unless every other DP worker is also microbatching
if not should_ubatch:
return (None, num_tokens_after_padding)

# This doesn't actually pad the ubatch slices. It just initializes the
# split point to the padded value so that padding can be applied
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
num_tokens_padded = int(num_tokens_after_padding[0].item())
token_split_point = int(num_tokens_padded) // 2

assert num_scheduled_tokens_per_request is not None
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded

return (ubatch_slices, num_tokens_after_padding)
return (should_ubatch, num_tokens_after_padding)
Comment on lines 213 to +215

P1: Update DP coordination call sites for bool return

The new coordinate_batch_across_dp now returns a (bool, Tensor) pair (return (should_ubatch, num_tokens_after_padding)), but vllm/v1/spec_decode/eagle.py::_pad_batch_across_dp still expects the first value to be None when microbatching is disabled and asserts ubatch_slices is None. With data-parallel EAGLE runs this now receives False instead, triggering the assertion before any work is done. The call sites at eagle.py:1261/1270 need to be adjusted to the new return type to avoid failing every DP EAGLE execution.
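For reference, a minimal sketch of a non-microbatching call site under the new contract. It mirrors the eagle.py hunk in this PR, but the helper name and the post-padding logic are illustrative assumptions, not the exact eagle.py code:

```python
import torch

from vllm.config import ParallelConfig
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp


def pad_batch_across_dp_sketch(
    parallel_config: ParallelConfig,
    num_tokens_unpadded: int,
    num_tokens_padded: int,
) -> tuple[int, torch.Tensor | None]:
    # The first return value is now a plain bool instead of `UBatchSlices | None`.
    should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
        num_tokens_unpadded=num_tokens_unpadded,
        parallel_config=parallel_config,
        allow_microbatching=False,
        allow_dp_padding=True,
        num_tokens_padded=num_tokens_padded,
        uniform_decode=None,
        num_scheduled_tokens_per_request=None,
    )
    # A caller that never microbatches only checks the flag.
    assert not should_ubatch, "DBO ubatching not implemented for this caller"

    # Assumption for illustration: pad to the largest per-rank token count.
    if num_toks_across_dp is not None:
        num_tokens_padded = int(num_toks_across_dp.max().item())
    return num_tokens_padded, num_toks_across_dp
```

Callers that do microbatch (see gpu_model_runner.py below) instead pass `should_ubatch` on to `maybe_create_ubatch_slices`, which now owns slice creation and padding.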


45 changes: 29 additions & 16 deletions vllm/v1/worker/gpu_model_runner.py
@@ -153,6 +153,7 @@
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
maybe_create_ubatch_slices,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp

@@ -2758,7 +2759,7 @@ def _determine_batch_execution_and_padding(
) -> tuple[
CUDAGraphMode,
BatchDescriptor,
UBatchSlices | None,
bool,
torch.Tensor | None,
CUDAGraphStat | None,
]:
@@ -2794,7 +2795,7 @@

# Extra coordination when running data-parallel since we need to coordinate
# across ranks
ubatch_slices, num_tokens_across_dp = None, None
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2804,8 +2805,8 @@
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)

ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_padded,
should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
@@ -2837,7 +2838,7 @@
return (
cudagraph_mode,
batch_descriptor,
ubatch_slices,
should_ubatch,
num_tokens_across_dp,
cudagraph_stats,
)
@@ -2936,7 +2937,7 @@ def execute_model(
(
cudagraph_mode,
batch_desc,
ubatch_slices,
should_ubatch,
num_tokens_across_dp,
cudagraph_stats,
) = self._determine_batch_execution_and_padding(
@@ -2949,29 +2950,37 @@

logger.debug(
"Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
"ubatch_slices: %s, num_tokens_across_dp: %s",
"should_ubatch: %s, num_tokens_across_dp: %s",
cudagraph_mode,
batch_desc,
ubatch_slices,
should_ubatch,
num_tokens_across_dp,
)

num_tokens_padded = batch_desc.num_tokens
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch,
num_scheduled_tokens_np,
num_tokens_padded,
num_reqs_padded,
)

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
pad_attn = cudagraph_mode == CUDAGraphMode.FULL

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

(attn_metadata, spec_decode_common_attn_metadata) = (
self._build_attention_metadata(
num_tokens=num_tokens_unpadded,
num_tokens_padded=num_tokens_padded if pad_attn else None,
num_reqs=num_reqs,
num_reqs_padded=num_reqs_padded if pad_attn else None,
max_query_len=max_num_scheduled_tokens,
ubatch_slices=ubatch_slices,
ubatch_slices=ubatch_slices_attn,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
@@ -3008,7 +3017,7 @@ def execute_model(
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices,
ubatch_slices=ubatch_slices_padded,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -3961,7 +3970,7 @@ def _dummy_run(

num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
_cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,
@@ -3995,6 +4004,9 @@
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
)

attn_metadata: PerLayerAttnMetadata | None = None

@@ -4016,11 +4028,12 @@
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()

pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs_padded,
max_query_len=max_query_len,
ubatch_slices=ubatch_slices,
ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
for_cudagraph_capture=is_graph_capturing,
)

@@ -4072,11 +4085,11 @@ def _dummy_run(
num_tokens_padded, None, False
)

if ubatch_slices is not None:
if ubatch_slices_padded is not None:
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in
# the padding refactor.
num_tokens_padded = ubatch_slices[0].num_tokens
num_tokens_padded = ubatch_slices_padded[0].num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[:] = num_tokens_padded

@@ -4089,7 +4102,7 @@ def _dummy_run(
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices,
ubatch_slices=ubatch_slices_padded,
),
):
outputs = self.model(
42 changes: 39 additions & 3 deletions vllm/v1/worker/ubatch_utils.py
@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
return num_tokens >= config.dbo_prefill_token_threshold


def create_ubatch_slices(
num_scheduled_tokens: np.ndarray, split_point: int
# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
# TODO(lucas): handle empty second ubatch
padded_second_request_slice = slice(
ubatch_slices[1].request_slice.start, num_reqs_padded
)
padded_second_token_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
return [
ubatch_slices[0],
UBatchSlice(padded_second_request_slice, padded_second_token_slice),
]


def maybe_create_ubatch_slices(
should_ubatch: bool,
num_scheduled_tokens: np.ndarray,
num_tokens_padded: int,
num_reqs_padded: int,
split_point: int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
if not should_ubatch:
return None, None

if split_point is None:
split_point = int(num_tokens_padded) // 2

# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
@@ -67,7 +95,15 @@ def create_ubatch_slices(
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)

return [
ubatch_slices = [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]

ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded
)

assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded

return ubatch_slices, ubatch_slices_padded
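
Taken together with the gpu_model_runner.py changes above, the intended usage looks roughly like the following. This is a minimal, self-contained sketch assuming the signature shown in this hunk; the concrete numbers (two requests, 16 padded tokens, 4 padded requests) are illustrative only:

```python
import numpy as np

from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices

# Two requests with 6 and 4 scheduled tokens; assume DP/cudagraph padding
# rounds the batch up to 16 tokens and 4 requests.
num_scheduled_tokens = np.array([6, 4], dtype=np.int32)

ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
    True,  # should_ubatch, as decided by coordinate_batch_across_dp
    num_scheduled_tokens,
    num_tokens_padded=16,
    num_reqs_padded=4,
)

# With no explicit split_point the split lands at num_tokens_padded // 2 == 8,
# so the unpadded second ubatch covers the actual tokens 8..10 while the padded
# one covers tokens 8..16 and requests up to the padded request count.
assert ubatch_slices is not None and len(ubatch_slices) == 2
assert sum(s.num_tokens for s in ubatch_slices_padded) == 16
```

The runner then builds attention metadata from the unpadded slices unless full-cudagraph padding is required (`pad_attn`), and always passes the padded slices into the ubatch forward context, as the execute_model and _dummy_run hunks above show.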