Commit 2111b46

SageMoore and mgoin authored
[Core] Simplify the Dp padding/should ubatch coordination logic (vllm-project#25768)
Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent c50901f commit 2111b46

File tree

10 files changed: +297, -462 lines


tests/v1/attention/test_attention_splitting.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
     split_attn_metadata,
     split_decodes_and_prefills,
 )
-from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
+from vllm.v1.worker.ubatch_utils import create_ubatch_slices


 @pytest.fixture

vllm/config/parallel.py

Lines changed: 4 additions & 0 deletions
@@ -152,6 +152,10 @@ class ParallelConfig:
     threshold, microbatching will be used. Otherwise, the request will be
     processed in a single batch."""

+    disable_nccl_for_dp_synchronization: bool = False
+    """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
+    to use Gloo instead of NCCL for its all reduce"""
+
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
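For context, a minimal sketch of toggling the new field programmatically; the field name and default come from this diff, while building a standalone ParallelConfig this way is purely illustrative:

# Illustrative only: `disable_nccl_for_dp_synchronization` is the field added
# above; constructing ParallelConfig directly like this is just a demonstration.
from vllm.config import ParallelConfig

parallel_config = ParallelConfig(
    data_parallel_size=2,
    disable_nccl_for_dp_synchronization=True,  # force the Gloo (CPU) all-reduce path
)
print(parallel_config.disable_nccl_for_dp_synchronization)  # True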

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions
@@ -365,6 +365,9 @@ class EngineArgs:
     enable_dbo: bool = ParallelConfig.enable_dbo
     dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
     dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
+    disable_nccl_for_dp_synchronization: bool = (
+        ParallelConfig.disable_nccl_for_dp_synchronization
+    )
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = (

@@ -760,6 +763,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "--dbo-prefill-token-threshold",
             **parallel_kwargs["dbo_prefill_token_threshold"],
         )
+        parallel_group.add_argument(
+            "--disable-nccl-for-dp-synchronization",
+            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
+        )
         parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
         parallel_group.add_argument(

@@ -1437,6 +1444,7 @@ def create_engine_config(
             enable_dbo=self.enable_dbo,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
+            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             expert_placement_strategy=self.expert_placement_strategy,
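A hedged usage sketch for the new engine argument; the flag and field names match this diff, while the model choice and the serving command in the comment are illustrative:

# Equivalent CLI form (illustrative invocation):
#   vllm serve <model> --data-parallel-size 2 --disable-nccl-for-dp-synchronization
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-125m",  # illustrative model choice
    disable_nccl_for_dp_synchronization=True,
)
# create_engine_config() forwards the flag into ParallelConfig (see the hunk above);
# actually running this step requires the model config to be resolvable locally or
# from the Hugging Face Hub.
vllm_config = args.create_engine_config()
assert vllm_config.parallel_config.disable_nccl_for_dp_synchronization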

vllm/envs.py

Lines changed: 0 additions & 7 deletions
@@ -95,7 +95,6 @@
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
-    VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
     VLLM_DISABLE_PYNCCL: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False

@@ -830,12 +829,6 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_DISABLED_KERNELS": lambda: []
     if "VLLM_DISABLED_KERNELS" not in os.environ
     else os.environ["VLLM_DISABLED_KERNELS"].split(","),
-    # Swaps the all reduce backend that we use to coordinate the DP padding
-    # information from NCCL to gloo.
-    "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": lambda: (
-        os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower()
-        in ("true", "1")
-    ),
     # Disable pynccl (using torch.distributed instead)
     "VLLM_DISABLE_PYNCCL": lambda: (
         os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")

vllm/forward_context.py

Lines changed: 9 additions & 120 deletions
@@ -8,13 +8,11 @@
 from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

 import torch
-import torch.distributed as dist

 import vllm.envs as envs
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
-from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
+from vllm.v1.worker.ubatch_utils import UBatchSlices

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata

@@ -87,129 +85,22 @@ class DPMetadata:
     # NOTE: local_sizes should only be set by the chunked_sizes context manager
     local_sizes: Optional[list[int]] = None

-    @staticmethod
-    def num_tokens_across_dp(
-        num_tokens: int, dp_size: int, dp_rank: int
-    ) -> torch.Tensor:
-        """
-        Gather the num_tokens across all DP ranks and return results in a
-        CPU tensor of size dp_size.
-        """
-        from vllm.distributed.parallel_state import get_dp_group
-
-        device = current_platform.device_type
-        group = get_dp_group().device_group
-
-        # Transfering this tensor from GPU to CPU will introduce a GPU sync
-        # point that could adversely affect performance of vllm with asynch
-        # scheduling. This environment variable exists to quickly disable
-        # this optimization if we run into this case.
-        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
-            logger.info_once(
-                "Using CPU all reduce to syncronize DP padding between ranks."
-            )
-            device = "cpu"
-            group = get_dp_group().cpu_group
-        num_tokens_across_dp = [0] * dp_size
-        num_tokens_across_dp[dp_rank] = num_tokens
-        num_tokens_tensor = torch.tensor(
-            num_tokens_across_dp, device=device, dtype=torch.int32
-        )
-        dist.all_reduce(num_tokens_tensor, group=group)
-        return num_tokens_tensor.cpu()
-
-    # Get the cumulative tokens across sequence parallel ranks.
-    # In this case the input to the MoEs will be distributed w.r.t both
-    # DP and TP rank.
-    # When sp_size==1, this is just the cummulative num tokens across DP.
-    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
-        num_tokens_across_sp_cpu = (
-            self.num_tokens_across_dp_cpu - 1 + sp_size
-        ) // sp_size
-        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
-        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
-
-    @staticmethod
-    def should_ubatch_across_dp(
-        should_ubatch: bool,
-        orig_num_tokens_per_ubatch: int,
-        padded_num_tokens_per_ubatch: int,
-        dp_size: int,
-        dp_rank: int,
-    ) -> tuple[bool, Optional[torch.Tensor]]:
-        """
-        1. Decides if each DP rank is going to microbatch. Either all ranks
-        run with microbatching or none of them do. If this function decides
-        not to run with microbatching. It will "abort" meaning that no padding
-        information will be returned to the caller. It will return (False, None)
-
-        2. Determines the total number of tokens that each rank will run.
-        All ranks will be padded out so that the run with the same number
-        of tokens
-
-        Returns: tuple[
-            should_ubatch: Are all DP ranks going to microbatch
-            num_tokens_after_padding: A tensor containing the total number of
-            tokens per-microbatch for each DP rank including padding. Will be
-            None if should_ubatch if False
-        ]
-        """
-
-        device = current_platform.device_type
-        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
-        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
-        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
-        tensor[2][dp_rank] = 1 if should_ubatch else 0
-
-        from vllm.distributed.parallel_state import get_dp_group
-
-        dist.all_reduce(tensor, group=get_dp_group().device_group)
-
-        result: bool = bool(torch.all(tensor[2] == 1).item())
-        if not result:
-            return result, None
-
-        orig_num_tokens_tensor = tensor[0, :]
-        padded_num_tokens_tensor = tensor[1, :]
-
-        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
-        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
-        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
-            logger.debug(
-                "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
-            )
-            return False, None
-        return result, padded_num_tokens_tensor.cpu()
-
     @staticmethod
     def make(
         parallel_config: ParallelConfig,
-        attn_metadata: Any,
         num_tokens: int,
-        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None,
+        num_tokens_across_dp_cpu: torch.Tensor,
     ) -> "DPMetadata":
+        assert num_tokens_across_dp_cpu is not None
         assert parallel_config.data_parallel_size > 1
-        dp_size = parallel_config.data_parallel_size
         dp_rank = parallel_config.data_parallel_rank
-        if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"):
-            # for v0 attention backends
-            batchsize = (
-                attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
-            )
-        else:
-            # for v1 attention backends or no attn_metadata
-            batchsize = num_tokens
+        batchsize = num_tokens

         # If num_tokens_across_dp is None, it will be computed by all_reduce
         # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (
-            num_tokens_across_dp_cpu is None
-            or num_tokens_across_dp_cpu[dp_rank] == batchsize
-        ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
-        if num_tokens_across_dp_cpu is None:
-            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
-                batchsize, dp_size, dp_rank
-            )
+        assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
+            f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
+        )
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
         return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

@@ -376,11 +267,9 @@ def set_forward_context(
     if vllm_config.parallel_config.data_parallel_size > 1 and (
         attn_metadata is not None or num_tokens is not None
     ):
+        assert num_tokens_across_dp is not None
         dp_metadata = DPMetadata.make(
-            vllm_config.parallel_config,
-            attn_metadata,
-            num_tokens or 0,
-            num_tokens_across_dp,
+            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
         )

     forward_context = create_forward_context(
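Per the new ParallelConfig docstring, the DP synchronization logic now lives in vllm/v1/worker/dp_utils.py. Below is a minimal, self-contained sketch of the coordination pattern removed from this file; the helper names and the explicit process-group arguments are illustrative, not the actual dp_utils.py API:

# Illustrative sketch of the removed coordination pattern. Assumes
# torch.distributed is already initialized and that the caller holds both a
# Gloo (CPU) group and an NCCL (device) group, as vLLM's DP group does.
import torch
import torch.distributed as dist


def gather_num_tokens_across_dp(
    num_tokens: int,
    dp_size: int,
    dp_rank: int,
    cpu_group: dist.ProcessGroup,
    device_group: dist.ProcessGroup,
    use_gloo: bool,  # corresponds to disable_nccl_for_dp_synchronization
) -> torch.Tensor:
    """All-reduce a one-hot vector of token counts so every rank sees all counts."""
    group = cpu_group if use_gloo else device_group
    device = "cpu" if use_gloo else "cuda"
    counts = torch.zeros(dp_size, dtype=torch.int32, device=device)
    counts[dp_rank] = num_tokens
    dist.all_reduce(counts, group=group)  # sum of one-hot vectors = per-rank counts
    return counts.cpu()


def all_ranks_agree_to_ubatch(should_ubatch: bool, cpu_group: dist.ProcessGroup) -> bool:
    """Microbatch only if every DP rank votes yes (MIN over 0/1 votes)."""
    vote = torch.tensor([int(should_ubatch)], dtype=torch.int32)  # CPU tensor, Gloo group
    dist.all_reduce(vote, op=dist.ReduceOp.MIN, group=cpu_group)
    return bool(vote.item())

Every rank can then derive the common padded size as counts.max(), mirroring max_tokens_across_dp_cpu in the simplified DPMetadata.make above.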
