|
8 | 8 | from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union |
9 | 9 |
|
10 | 10 | import torch |
11 | | -import torch.distributed as dist |
12 | 11 |
|
13 | 12 | import vllm.envs as envs |
14 | 13 | from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig |
15 | 14 | from vllm.logger import init_logger |
16 | | -from vllm.platforms import current_platform |
17 | | -from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty |
| 15 | +from vllm.v1.worker.ubatch_utils import UBatchSlices |
18 | 16 |
|
19 | 17 | if TYPE_CHECKING: |
20 | 18 | from vllm.attention.backends.abstract import AttentionMetadata |
@@ -87,129 +85,22 @@ class DPMetadata: |
87 | 85 | # NOTE: local_sizes should only be set by the chunked_sizes context manager |
88 | 86 | local_sizes: Optional[list[int]] = None |
89 | 87 |
|
90 | | - @staticmethod |
91 | | - def num_tokens_across_dp( |
92 | | - num_tokens: int, dp_size: int, dp_rank: int |
93 | | - ) -> torch.Tensor: |
94 | | - """ |
95 | | - Gather the num_tokens across all DP ranks and return results in a |
96 | | - CPU tensor of size dp_size. |
97 | | - """ |
98 | | - from vllm.distributed.parallel_state import get_dp_group |
99 | | - |
100 | | - device = current_platform.device_type |
101 | | - group = get_dp_group().device_group |
102 | | - |
103 | | - # Transferring this tensor from GPU to CPU will introduce a GPU sync
104 | | - # point that could adversely affect the performance of vLLM with async
105 | | - # scheduling. This environment variable exists to quickly disable
106 | | - # this optimization if we run into this case.
107 | | - if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: |
108 | | - logger.info_once( |
109 | | - "Using CPU all reduce to syncronize DP padding between ranks." |
110 | | - ) |
111 | | - device = "cpu" |
112 | | - group = get_dp_group().cpu_group |
113 | | - num_tokens_across_dp = [0] * dp_size |
114 | | - num_tokens_across_dp[dp_rank] = num_tokens |
115 | | - num_tokens_tensor = torch.tensor( |
116 | | - num_tokens_across_dp, device=device, dtype=torch.int32 |
117 | | - ) |
118 | | - dist.all_reduce(num_tokens_tensor, group=group) |
119 | | - return num_tokens_tensor.cpu() |
120 | | - |
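For context, the removed `num_tokens_across_dp` helper relied on a common gather-via-all-reduce pattern: each rank writes its token count into its own slot of a zero vector, and a SUM all-reduce leaves every rank holding the full per-rank vector. Below is a minimal standalone sketch of that pattern using the gloo backend on CPU with hypothetical token counts; it is not vLLM code.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, base_tokens: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Hypothetical per-rank token count, just to make the ranks differ.
    num_tokens = base_tokens * (rank + 1)

    # Each rank fills only its own slot; the SUM all-reduce acts as a gather.
    counts = torch.zeros(world_size, dtype=torch.int32)
    counts[rank] = num_tokens
    dist.all_reduce(counts)  # default reduce op is SUM
    print(f"rank {rank}: tokens across DP = {counts.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size, 128), nprocs=world_size)
```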
121 | | - # Get the cumulative tokens across sequence parallel ranks.
122 | | - # In this case the input to the MoEs will be distributed w.r.t. both
123 | | - # DP and TP ranks.
124 | | - # When sp_size==1, this is just the cumulative num tokens across DP.
125 | | - def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: |
126 | | - num_tokens_across_sp_cpu = ( |
127 | | - self.num_tokens_across_dp_cpu - 1 + sp_size |
128 | | - ) // sp_size |
129 | | - num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size) |
130 | | - return torch.cumsum(num_tokens_across_sp_cpu, dim=0) |
131 | | - |
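The removed `cu_tokens_across_sp` computed per-(DP, SP) slice boundaries: ceil-divide each DP rank's token count by `sp_size`, repeat that value once per SP rank, then take the cumulative sum. A small sketch of the arithmetic with made-up counts (not vLLM code):

```python
import torch

# Two DP ranks with 5 and 8 tokens respectively; two SP ranks per DP rank.
num_tokens_across_dp_cpu = torch.tensor([5, 8], dtype=torch.int32)
sp_size = 2

per_sp = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size  # ceil division -> [3, 4]
per_sp = per_sp.repeat_interleave(sp_size)                    # [3, 3, 4, 4]
cu_tokens = torch.cumsum(per_sp, dim=0)                       # [3, 6, 10, 14]
print(cu_tokens.tolist())
```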
132 | | - @staticmethod |
133 | | - def should_ubatch_across_dp( |
134 | | - should_ubatch: bool, |
135 | | - orig_num_tokens_per_ubatch: int, |
136 | | - padded_num_tokens_per_ubatch: int, |
137 | | - dp_size: int, |
138 | | - dp_rank: int, |
139 | | - ) -> tuple[bool, Optional[torch.Tensor]]: |
140 | | - """ |
141 | | - 1. Decides if each DP rank is going to microbatch. Either all ranks
142 | | - run with microbatching or none of them do. If this function decides
143 | | - not to run with microbatching, it will "abort", meaning that no padding
144 | | - information is returned to the caller, i.e. it returns (False, None).
145 | | -
|
146 | | - 2. Determines the total number of tokens that each rank will run. |
147 | | - All ranks will be padded out so that they run with the same number
148 | | - of tokens.
149 | | -
|
150 | | - Returns: tuple[ |
151 | | - should_ubatch: Are all DP ranks going to microbatch |
152 | | - num_tokens_after_padding: A tensor containing the total number of |
153 | | - tokens per microbatch for each DP rank, including padding. Will be
154 | | - None if should_ubatch is False.
155 | | - ] |
156 | | - """ |
157 | | - |
158 | | - device = current_platform.device_type |
159 | | - tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32) |
160 | | - tensor[0][dp_rank] = orig_num_tokens_per_ubatch |
161 | | - tensor[1][dp_rank] = padded_num_tokens_per_ubatch |
162 | | - tensor[2][dp_rank] = 1 if should_ubatch else 0 |
163 | | - |
164 | | - from vllm.distributed.parallel_state import get_dp_group |
165 | | - |
166 | | - dist.all_reduce(tensor, group=get_dp_group().device_group) |
167 | | - |
168 | | - result: bool = bool(torch.all(tensor[2] == 1).item()) |
169 | | - if not result: |
170 | | - return result, None |
171 | | - |
172 | | - orig_num_tokens_tensor = tensor[0, :] |
173 | | - padded_num_tokens_tensor = tensor[1, :] |
174 | | - |
175 | | - orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) |
176 | | - padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) |
177 | | - if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): |
178 | | - logger.debug( |
179 | | - "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens |
180 | | - ) |
181 | | - return False, None |
182 | | - return result, padded_num_tokens_tensor.cpu() |
183 | | - |
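The removed `should_ubatch_across_dp` reached consensus by packing (orig tokens, padded tokens, ubatch flag) into one column per rank of a 3 x dp_size tensor and all-reducing it over the DP group; microbatching only proceeds if every rank voted yes. The sketch below simulates that merge on a single process with assumed values, with a plain sum standing in for the all-reduce (not vLLM code):

```python
import torch

dp_size = 4
# Hypothetical per-rank inputs: (orig_num_tokens, padded_num_tokens, should_ubatch)
per_rank = [(96, 128, True), (110, 128, True), (120, 128, True), (64, 128, True)]

merged = torch.zeros(3, dp_size, dtype=torch.int32)
for rank, (orig, padded, flag) in enumerate(per_rank):
    local = torch.zeros(3, dp_size, dtype=torch.int32)
    local[0, rank] = orig
    local[1, rank] = padded
    local[2, rank] = int(flag)
    merged += local  # stands in for the SUM all-reduce across the DP group

# All ranks must agree before microbatching is enabled.
should_ubatch = bool(torch.all(merged[2] == 1).item())
print("all ranks agree to microbatch:", should_ubatch)
print("padded tokens per rank:", merged[1].tolist())
```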
184 | 88 | @staticmethod |
185 | 89 | def make( |
186 | 90 | parallel_config: ParallelConfig, |
187 | | - attn_metadata: Any, |
188 | 91 | num_tokens: int, |
189 | | - num_tokens_across_dp_cpu: Optional[torch.Tensor] = None, |
| 92 | + num_tokens_across_dp_cpu: torch.Tensor, |
190 | 93 | ) -> "DPMetadata": |
| 94 | + assert num_tokens_across_dp_cpu is not None |
191 | 95 | assert parallel_config.data_parallel_size > 1 |
192 | | - dp_size = parallel_config.data_parallel_size |
193 | 96 | dp_rank = parallel_config.data_parallel_rank |
194 | | - if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"): |
195 | | - # for v0 attention backends |
196 | | - batchsize = ( |
197 | | - attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens |
198 | | - ) |
199 | | - else: |
200 | | - # for v1 attention backends or no attn_metadata |
201 | | - batchsize = num_tokens |
| 97 | + batchsize = num_tokens |
202 | 98 |
|
203 | 99 | # If num_tokens_across_dp is None, it will be computed by all_reduce |
204 | 100 | # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize |
205 | | - assert ( |
206 | | - num_tokens_across_dp_cpu is None |
207 | | - or num_tokens_across_dp_cpu[dp_rank] == batchsize |
208 | | - ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" |
209 | | - if num_tokens_across_dp_cpu is None: |
210 | | - num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp( |
211 | | - batchsize, dp_size, dp_rank |
212 | | - ) |
| 101 | + assert num_tokens_across_dp_cpu[dp_rank] == batchsize, ( |
| 102 | + f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" |
| 103 | + ) |
213 | 104 | max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) |
214 | 105 | return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) |
215 | 106 |
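After this change, `DPMetadata.make` no longer gathers token counts itself: the caller must supply `num_tokens_across_dp_cpu`, and `make` only checks that its own slot matches the local batch size and records the maximum used for padding. A toy illustration of that contract with assumed values (not the vLLM implementation):

```python
import torch

# Hypothetical per-rank token counts already gathered by the caller.
num_tokens_across_dp_cpu = torch.tensor([96, 128, 64, 112], dtype=torch.int32)
dp_rank, num_tokens = 1, 128

# make() asserts the caller-provided slot agrees with the local count...
assert int(num_tokens_across_dp_cpu[dp_rank]) == num_tokens
# ...and keeps the max so every rank can pad up to the same size.
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
print("pad every rank up to", int(max_tokens_across_dp_cpu), "tokens")
```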
|
@@ -376,11 +267,9 @@ def set_forward_context( |
376 | 267 | if vllm_config.parallel_config.data_parallel_size > 1 and ( |
377 | 268 | attn_metadata is not None or num_tokens is not None |
378 | 269 | ): |
| 270 | + assert num_tokens_across_dp is not None |
379 | 271 | dp_metadata = DPMetadata.make( |
380 | | - vllm_config.parallel_config, |
381 | | - attn_metadata, |
382 | | - num_tokens or 0, |
383 | | - num_tokens_across_dp, |
| 272 | + vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp |
384 | 273 | ) |
385 | 274 |
|
386 | 275 | forward_context = create_forward_context( |
|