diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index be69075f94f0..4f29cbc01a2f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import pass_context +from vllm.compilation.nanoflow import manager as nano_manager from vllm.compilation.partition_rules import ( inductor_partition_rule_context, should_split, @@ -694,9 +695,14 @@ def __call__( self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or not self.compilation_config.cudagraph_copy_inputs ): - return VllmSerializableFunction( - graph, example_inputs, self.prefix, self.split_gm - ) + if self.compilation_config.enable_nano_batch_split: + return nano_manager.get_callable( + self.split_gm, self.compilation_config, local_cache_dir + ) + else: + return VllmSerializableFunction( + graph, example_inputs, self.prefix, self.split_gm + ) # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode diff --git a/vllm/compilation/nanoflow/__init__.py b/vllm/compilation/nanoflow/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py new file mode 100644 index 000000000000..a7b80ac20476 --- /dev/null +++ b/vllm/compilation/nanoflow/manager.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import copy +import os +from collections.abc import Callable + +import torch +import torch.fx.graph_module + +from vllm.compilation.nanoflow.split_utils import ( + FakeModule, + NanoOpInfo, + NanoSplitConfig, + analyze_graph, + get_split_config, + split_graph, + tag_graph, +) +from vllm.config import CompilationConfig + + +class NanoSplitManager: + def __init__( + self, + graph_module: torch.fx.GraphModule, + compilation_config: CompilationConfig, + local_cache_dir: str | None, + ) -> None: + self.original_graph_module = graph_module + self.original_graph = graph_module.graph + + # Nano split preparation + self.min_nano_split_tokens = compilation_config.min_nano_split_tokens + self.max_num_nano_batches = compilation_config.max_num_nano_batches + # Initialize the base graph + tag_graph( + self.original_graph_module, + { + "vllm.all_reduce": "all_reduce", + }, + ) + self.graph_modules = {1: self.original_graph_module} + + # Runtime preparation + self.cached_config: NanoSplitConfig | None = None + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + self.comp_stream: torch.cuda.Stream = torch.cuda.Stream() + self.hook: ( + Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] | None + ) = None + self.get_bs_fn = "get_batch_size" + self.split_fn = "split_input" + self.wrapper_fn = "op_wrapper" + setattr(self.original_graph_module, self.get_bs_fn, None) + setattr(self.original_graph_module, self.split_fn, None) + setattr(self.original_graph_module, self.wrapper_fn, None) + + splittable_inputs, base_graph = analyze_graph(self.original_graph) + for num_splits in range(2, self.max_num_nano_batches + 1): + new_graph = copy.deepcopy(base_graph) + split_graph( + self.original_graph, + out=new_graph, + splittable_inputs=splittable_inputs, + num_splits=num_splits, + get_bs_fn=self.get_bs_fn, + split_fn=self.split_fn, + wrapper_fn=self.wrapper_fn, + ) + new_graph_module = torch.fx.GraphModule( + self.original_graph_module, new_graph + ) + for name, _ in self.original_graph_module.named_modules(): 
+ if "." in name or name == "": + continue + torch.fx.graph_module._copy_attr( + self.original_graph_module, new_graph_module, name + ) + self.graph_modules[num_splits] = new_graph_module + if local_cache_dir is not None: + graph_path = os.path.join( + local_cache_dir, f"nano_split_{num_splits}.py" + ) + if not os.path.exists(graph_path): + src = ( + "from __future__ import annotations\nimport torch\n" + + new_graph_module.print_readable(print_output=False) + ) + src = src.replace("", "GraphModule") + with open(graph_path, "w") as f: + f.write(src) + + @staticmethod + def get_batch_size(idx: int, cached_config: NanoSplitConfig): + return cached_config.num_tokens[idx] + + @staticmethod + def split_input(x: torch.Tensor, idx: int, cached_config: NanoSplitConfig): + return x[ + cached_config.split_indices[idx] : cached_config.split_indices[idx + 1] + ] + + @staticmethod + def op_wrapper( + submod_name: str, + idx: int, + args: tuple, + kwargs: dict, + gm: torch.fx.GraphModule, + hooks: list[Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]]], + ): + module = getattr(gm, submod_name) + tag = getattr(module, "tag", "") + with contextlib.ExitStack() as stack: + for hook in hooks: + stack.enter_context( + hook(NanoOpInfo(submod_name, tag, idx, args, kwargs)) + ) + output = module(*args, **kwargs) + return output + + def get_callable(self) -> Callable: + def _forward(*args, **kwargs): + if self.cached_config is None or self.cached_config.num_nano_batches == 1: + return self.original_graph_module(*args, **kwargs) + + num_nano_batches = self.cached_config.num_nano_batches + comm_finished: list[torch.cuda.Event | None] = [ + None for _ in range(num_nano_batches) + ] + comp_finished: list[torch.cuda.Event | None] = [ + None for _ in range(num_nano_batches) + ] + + @contextlib.contextmanager + def set_stream(op_info: NanoOpInfo): + if op_info.tag == "all_reduce": + torch.cuda.set_stream(self.comm_stream) + comm_finished[op_info.idx] = torch.cuda.Event() + if comp_finished[op_info.idx] is not None: + # NOTE(yi): this is to make mypy happy + comp_finished_event = comp_finished[op_info.idx] + assert comp_finished_event is not None + comp_finished_event.wait() + comp_finished[op_info.idx] = None + else: + torch.cuda.set_stream(self.comp_stream) + comp_finished[op_info.idx] = torch.cuda.Event() + if comm_finished[op_info.idx] is not None: + comm_finished_event = comm_finished[op_info.idx] + assert comm_finished_event is not None + comm_finished_event.wait() + comm_finished[op_info.idx] = None + try: + yield + except: + raise + finally: + if op_info.tag == "all_reduce": + comm_finished_event = comm_finished[op_info.idx] + assert comm_finished_event is not None + comm_finished_event.record() + else: + comp_finished_event = comp_finished[op_info.idx] + assert comp_finished_event is not None + comp_finished_event.record() + + @contextlib.contextmanager + def nvtx_mark(op_info: NanoOpInfo): + try: + with torch.cuda.nvtx.range( + f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" + ): + yield + except: + raise + + # Register fake modules + assert self.hook is not None + op_wrapper = FakeModule( + NanoSplitManager.op_wrapper, + gm=self.graph_modules[num_nano_batches], + hooks=[ + set_stream, + nvtx_mark, + self.hook, + ], + ) + get_batch_size = FakeModule( + NanoSplitManager.get_batch_size, + cached_config=self.cached_config, + ) + split_input = FakeModule( + NanoSplitManager.split_input, + cached_config=self.cached_config, + ) + setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, 
op_wrapper) + setattr( + self.graph_modules[num_nano_batches], self.get_bs_fn, get_batch_size + ) + setattr(self.graph_modules[num_nano_batches], self.split_fn, split_input) + output = self.graph_modules[num_nano_batches](*args, **kwargs) + return output + + return _forward + + def prepare( + self, + batch_size: int, + num_tokens: list[int], + ) -> NanoSplitConfig: + self.cached_config = get_split_config( + batch_size, + num_tokens, + self.max_num_nano_batches, + self.min_nano_split_tokens, + ) + return self.cached_config + + def set_hooks( + self, op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] + ): + self.hook = op_hook + + +_split_manager = None + + +def get_callable( + graph_module: torch.fx.GraphModule, + compilation_config: CompilationConfig, + local_cache_dir: str | None = None, +) -> Callable: + global _split_manager + if _split_manager is None: + _split_manager = NanoSplitManager( + graph_module, compilation_config, local_cache_dir + ) + return _split_manager.get_callable() + + +def prepare_nano_split( + batch_size: int, + num_tokens: list[int], +) -> NanoSplitConfig: + global _split_manager + if _split_manager is None: + raise ValueError("Split manager not initialized") + return _split_manager.prepare(batch_size, num_tokens) + + +def set_op_hook( + op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]], +): + global _split_manager + if _split_manager is None: + raise ValueError("Split manager not initialized") + _split_manager.set_hooks(op_hook) diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py new file mode 100644 index 000000000000..39c1da0b1b71 --- /dev/null +++ b/vllm/compilation/nanoflow/split_utils.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import itertools +from typing import Callable, Union + +import torch +from torch.fx.node import Argument as NodeArgument + + +@dataclasses.dataclass +class NanoOpInfo: + submod_name: str + tag: str + idx: int + args: tuple + kwargs: dict + + +@dataclasses.dataclass +class NanoSplitConfig: + num_nano_batches: int + # Request level information + batch_sizes: list[int] + batch_indices: list[int] + # Token level information + num_tokens: list[int] + split_indices: list[int] # start/end indices of each nano batch + + +class FakeModule(torch.nn.Module): + def __init__(self, fn: Callable, **kwargs): + super().__init__() + self.fn = fn + self.kwargs = kwargs + + def forward(self, *args, **kwargs): + return self.fn(*args, **self.kwargs, **kwargs) + + +def get_split_config( + batch_size: int, + num_tokens: list[int], + max_num_nano_batches: int, + min_nano_split_tokens: int, +) -> NanoSplitConfig: + num_nano_batches = 0 + nano_batch_token_indices = [0] + nano_batch_req_indices = [0] + nano_batch_req_sizes = [] + nano_batch_token_sizes = [] + prefix_sum = [0] + list(itertools.accumulate(num_tokens)) + # Find the mid point of the tokens + mid = min( + range(len(prefix_sum)), + key=lambda i: abs(prefix_sum[i] - (prefix_sum[-1] - prefix_sum[i])), + ) + if ( + prefix_sum[mid] < min_nano_split_tokens + or (prefix_sum[-1] - prefix_sum[mid]) < min_nano_split_tokens + ): + num_nano_batches = 1 + nano_batch_req_indices.append(batch_size) + nano_batch_token_indices.append(prefix_sum[-1]) + nano_batch_req_sizes.append(batch_size) + nano_batch_token_sizes.append(prefix_sum[-1]) + else: + num_nano_batches = 2 + nano_batch_req_indices.extend([mid, batch_size]) + 
nano_batch_token_indices.extend([prefix_sum[mid], prefix_sum[-1]]) + nano_batch_req_sizes.extend([mid, batch_size - mid]) + nano_batch_token_sizes.extend( + [prefix_sum[mid], prefix_sum[-1] - prefix_sum[mid]] + ) + + return NanoSplitConfig( + num_nano_batches=num_nano_batches, + batch_sizes=nano_batch_req_sizes, + batch_indices=nano_batch_req_indices, + num_tokens=nano_batch_token_sizes, + split_indices=nano_batch_token_indices, + ) + + +def display_graph(graph_module: torch.fx.GraphModule, name: str): + from torch._dynamo.utils import lazy_format_graph_code # type: ignore + + print(lazy_format_graph_code(name, graph_module)) + + +def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): + submodules = [ + (name, module) + for (name, module) in gm.named_modules() + if hasattr(module, "graph") + ] + for _, module in submodules: + for node in module.graph.nodes: + if ( + node.op == "call_function" + and (tag := op_tags.get(str(node.target))) is not None + ): + assert getattr(module, "tag", None) is None or module.tag == tag, ( + f"tag mismatch: {module.tag} != {tag}" + ) + module.tag = tag + + +def analyze_graph( + graph: torch.fx.Graph, batch_size: Union[int, torch.SymInt, None] = None +) -> tuple[list[torch.fx.Node], torch.fx.Graph]: + weight_nodes = set() + splittable_inputs = [] + base_graph = torch.fx.Graph() + for node in graph.nodes: + # Skip computation nodes + if node.op != "placeholder": + continue + + # We assume the batch size is the first argument + if batch_size is None: + arg = node.meta["example_value"] + if not isinstance(arg, torch.SymInt): + raise ValueError("Batch size is not set") + batch_size = arg + elif isinstance(input_tensor := node.meta["example_value"], torch.Tensor): + shape = input_tensor.shape + if shape[0] == batch_size: + splittable_inputs.append(node) + else: + weight_nodes.add(node) + # Copy all placeholder nodes to the new graph + base_graph.node_copy(node, arg_transform=lambda n: n) + return splittable_inputs, base_graph + + +def split_graph( + graph: torch.fx.Graph, + *, + out: torch.fx.Graph, + splittable_inputs: list[torch.fx.Node], + num_splits: int, + get_bs_fn: str, + split_fn: str, + wrapper_fn: str, +) -> torch.fx.Graph: + mapping: dict[NodeArgument, list[torch.fx.Node]] = {} + nano_batch_sizes = [] + + # Step 1: Get nano batch sizes and split inputs + for i in range(num_splits): + nano_batch_sizes.append( + out.call_module( + get_bs_fn, + args=(i,), + ) + ) + for node in splittable_inputs: + mapping[node] = [] + for i in range(num_splits): + slice_node = out.call_module( + split_fn, + args=(node, i), + ) + mapping[node].append(slice_node) + + # Step 2: Split computation nodes + def _transform(idx: int, n: NodeArgument) -> NodeArgument: + if n in mapping: + return mapping[n][idx] + if isinstance(getattr(n, "meta", {}).get("example_value", None), torch.SymInt): + return nano_batch_sizes[idx] + return n + + for node in graph.nodes: + if node.op in ["placeholder", "output"]: + continue + splits = [] + for split_idx in range(num_splits): + if node.op == "call_module": + new_args = [_transform(split_idx, arg) for arg in node.args] + new_kwargs = { + k: _transform(split_idx, v) for k, v in node.kwargs.items() + } + new_node = out.call_module( + wrapper_fn, + args=(str(node.target), split_idx, new_args, new_kwargs), + ) + else: + new_node = out.node_copy( + node, arg_transform=lambda n, idx=split_idx: _transform(idx, n) + ) + splits.append(new_node) + mapping[node] = splits + + # Step 3: Concatenate outputs + output_nodes = [node for node in 
graph.nodes if node.op == "output"] + assert len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" + output_node = output_nodes[0] + if not output_node.args: + raise ValueError("Output node has no arguments") + original_outputs = output_node.args[0] + is_tuple = isinstance(original_outputs, tuple) + if not isinstance(original_outputs, tuple): + original_outputs = (original_outputs,) + new_outputs = [] + + for original_output in original_outputs: + if original_output in mapping: + # Get all split outputs + split_outputs = mapping[original_output] + + # Create concatenation node + if len(split_outputs) == 1: + # If there's only one split, no need to concatenate + concat_node = split_outputs[0] + else: + # Create concatenation node + concat_node = out.call_function( + torch.cat, + args=(split_outputs, 0), # Concatenate along first dimension + ) + + new_outputs.append(concat_node) + else: + raise ValueError( + f"Original output {original_output} not found in node_splits" + ) + + out.output(tuple(new_outputs) if is_tuple else new_outputs[0]) + return out diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c84a060922e3..092924b833a7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -408,6 +408,13 @@ class CompilationConfig: inductor `call` function in the model runner. The top-level full cudagraph capture ignores all partitioning. """ + enable_nano_batch_split: bool = False + """Enable splitting the input batch into nano-batches for intra-device + parallelism""" + max_num_nano_batches: int = 2 + """Maximum number of nano-batches to split the input batch into""" + min_nano_split_tokens: int = 1024 + """Minimum number of tokens to split the input batch""" pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -494,6 +501,9 @@ def compute_hash(self) -> str: factors.append(self.inductor_passes) factors.append(self.pass_config.uuid()) factors.append(self.compile_cache_save_format) + factors.append(self.enable_nano_batch_split) + factors.append(self.max_num_nano_batches) + factors.append(self.min_nano_split_tokens) return hashlib.sha256(str(factors).encode()).hexdigest() def __repr__(self) -> str: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d4ee6f980e6e..5a5c1db51dd1 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -414,6 +414,34 @@ def __post_init__(self): "precision for chunked prefill triton kernels." ) + if self.compilation_config.enable_nano_batch_split: + if self.model_config.enforce_eager: + logger.info( + "nano batch split is not supported with " + "enforce_eager. Disabling nano batch split." + ) + self.compilation_config.enable_nano_batch_split = False + elif self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + logger.info( + "nano batch split is currently not supported with " + "cudagraph. Disabling nano batch split." + ) + self.compilation_config.enable_nano_batch_split = False + elif self.compilation_config.full_cuda_graph: + logger.info( + "full_cuda_graph is not supported with " + "nano batch split. Disabling nano batch split." + ) + self.compilation_config.enable_nano_batch_split = False + elif ( + self.compilation_config.splitting_ops + and "vllm.all_reduce" not in self.compilation_config.splitting_ops + ): + logger.info( + "adding vllm.all_reduce to splitting_ops for nano batch split." 
+ ) + self.compilation_config.splitting_ops.append("vllm.all_reduce") + # If the user does not explicitly set a compilation mode, then # we use the default mode. The default mode depends on other # settings (see the below code). diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9403b5756e05..4321cffc370c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -138,6 +138,7 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.nano_batch_split import nano_ubatch_split from vllm.v1.worker.ubatch_utils import ( UBatchSlice, UBatchSlices, @@ -1202,22 +1203,28 @@ def _prepare_inputs( uniform_decode = ( max_num_scheduled_tokens == self.uniform_decode_query_len ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + if self.compilation_config.enable_nano_batch_split: + ubatch_slices, num_tokens_across_dp = nano_ubatch_split( + num_scheduled_tokens, num_tokens_unpadded, num_tokens_padded + ) + else: + # Disable DP padding when running eager to avoid excessive padding when + # running prefills. This lets us set enforce_eager on the prefiller in + # a P/D setup and still use CUDA graphs (enabled by this padding) on the + # decoder. + allow_dp_padding = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) - # Disable DP padding when running eager to avoid excessive padding when - # running prefills. This lets us set enforce_eager on the prefiller in - # a P/D setup and still use CUDA graphs (enabled by this padding) on the - # decoder. - allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - - ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( - num_tokens_unpadded=num_tokens_unpadded, - parallel_config=self.parallel_config, - allow_microbatching=True, - allow_dp_padding=allow_dp_padding, - num_tokens_padded=num_tokens_padded, - uniform_decode=uniform_decode, - num_scheduled_tokens_per_request=num_scheduled_tokens, - ) + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens_unpadded, + parallel_config=self.parallel_config, + allow_microbatching=True, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, + ) self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens @@ -4190,6 +4197,7 @@ def initialize_metadata_builders( else None, num_metadata_builders=1 if not self.parallel_config.enable_dbo + and not self.compilation_config.enable_nano_batch_split else 2, ) # Calculate reorder batch threshold (if needed) diff --git a/vllm/v1/worker/nano_batch_split.py b/vllm/v1/worker/nano_batch_split.py new file mode 100644 index 000000000000..7c042ddb0918 --- /dev/null +++ b/vllm/v1/worker/nano_batch_split.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Optional + +import numpy as np +import torch + +from vllm.compilation.nanoflow import manager as nano_manager +from vllm.compilation.nanoflow.split_utils import NanoOpInfo +from vllm.forward_context import get_forward_context +from vllm.v1.worker.ubatch_utils import UBatchSlice, 
UBatchSlices + + +def nano_ubatch_split( + num_scheduled_tokens_per_request: np.ndarray, + num_tokens_unpadded: int, + num_tokens_padded: int, +) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]: + """ + Prepare two UBatch-compatible nano-batch slices. + + - Uses nano_manager.prepare_nano_split to decide if splitting is beneficial + (i.e., num_nano_batches > 1). + - Computes a single token split point using custom logic to remain + compatible with UBatch execution. + """ + assert num_tokens_unpadded == num_tokens_padded + batch_size = int(len(num_scheduled_tokens_per_request)) + total_tokens = int(np.sum(num_scheduled_tokens_per_request)) + if batch_size <= 1 or total_tokens <= 1: + return (None, None) + + tokens_list = num_scheduled_tokens_per_request.tolist() + split_config = nano_manager.prepare_nano_split(batch_size, tokens_list) + if getattr(split_config, "num_nano_batches", 1) <= 1: + return (None, None) + assert split_config.num_nano_batches == 2 + + first_slice = UBatchSlice( + slice(0, split_config.batch_indices[1]), slice(0, split_config.split_indices[1]) + ) + second_slice = UBatchSlice( + slice(split_config.batch_indices[1], batch_size), + slice(split_config.split_indices[1], split_config.split_indices[2]), + ) + + @contextmanager + def op_hook(op_info: NanoOpInfo): + ctx = get_forward_context() + attn_metadata_list = ctx.attn_metadata + assert isinstance(attn_metadata_list, list) + ctx.attn_metadata = attn_metadata_list[op_info.idx] + try: + yield + finally: + ctx.attn_metadata = attn_metadata_list + pass + + nano_manager.set_op_hook(op_hook) + + return ( + [first_slice, second_slice], + torch.tensor([num_tokens_padded], device="cpu", dtype=torch.int32), + )
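
A minimal sketch of turning the feature on through the three new CompilationConfig fields added above, assuming the usual compilation_config argument on LLM; the model name is a placeholder:

from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    compilation_config=CompilationConfig(
        enable_nano_batch_split=True,  # off by default
        max_num_nano_batches=2,        # the runtime splitter currently emits 1 or 2 nano-batches
        min_nano_split_tokens=1024,    # both halves must clear this token count
    ),
)

Per the __post_init__ check in vllm/config/vllm.py above, the flag is silently cleared when enforce_eager is set or any cudagraph mode is active, so a run that keeps it enabled also needs cudagraph capture turned off.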
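
The balance-point selection in get_split_config can be read in isolation; a standalone sketch of the same heuristic (the pick_split name is made up for illustration):

import itertools

def pick_split(num_tokens: list[int], min_nano_split_tokens: int) -> int | None:
    # Request boundary whose prefix sum best balances tokens between the halves.
    prefix = [0] + list(itertools.accumulate(num_tokens))
    total = prefix[-1]
    mid = min(range(len(prefix)), key=lambda i: abs(prefix[i] - (total - prefix[i])))
    if prefix[mid] < min_nano_split_tokens or total - prefix[mid] < min_nano_split_tokens:
        return None  # fall back to a single nano-batch
    return mid

# Example: 4 requests with uneven token counts.
tokens = [4096, 512, 512, 3072]
print(pick_split(tokens, min_nano_split_tokens=1024))  # -> 1

With these counts the first nano-batch takes request 0 (4096 tokens) and the second takes requests 1..3 (also 4096 tokens), which is the 50/50 token split the heuristic aims for.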
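
split_graph leans on the fact that token-parallel ops commute with slicing along the token dimension: running each slice separately and concatenating along dim 0 must match running the full batch. A small sketch of that invariant, with an arbitrary Linear layer standing in for a graph submodule:

import torch

layer = torch.nn.Linear(16, 16)
x = torch.randn(10, 16)   # 10 "tokens"
split = 6                 # split_indices would look like [0, 6, 10]

full = layer(x)
nano = torch.cat([layer(x[:split]), layer(x[split:])], dim=0)
assert torch.allclose(full, nano, atol=1e-6)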
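
The set_stream hook alternates each nano-batch between a compute stream and a communication stream, using CUDA events so the two streams stay ordered within a nano-batch while the other nano-batch keeps the GPU busy. A reduced sketch of that pattern, assuming a CUDA device; torch.cuda.stream is used here instead of set_stream, and an in-place multiply stands in for vllm.all_reduce:

import torch

if torch.cuda.is_available():
    comp_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    x = torch.randn(4096, 4096, device="cuda")

    comp_done = torch.cuda.Event()
    with torch.cuda.stream(comp_stream):
        y = x @ x              # "compute" op for one nano-batch
        comp_done.record()     # recorded on the compute stream
    with torch.cuda.stream(comm_stream):
        comp_done.wait()       # this nano-batch's communication waits for its compute
        y.mul_(2)              # stand-in for the all_reduce on the comm stream
    torch.cuda.synchronize()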
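
op_wrapper composes stream routing, NVTX ranges, and the runner-installed hook as stacked context managers around every submodule call. A toy sketch of that composition; the logging hook and run_op helper are made up:

import contextlib

def make_logging_hook(name: str):
    @contextlib.contextmanager
    def hook(op_info):
        print(f"[{name}] enter {op_info}")
        try:
            yield
        finally:
            print(f"[{name}] exit {op_info}")
    return hook

def run_op(op_info, hooks, fn, *args):
    # Mirror op_wrapper: enter every hook, then call the wrapped op.
    with contextlib.ExitStack() as stack:
        for hook in hooks:
            stack.enter_context(hook(op_info))
        return fn(*args)

hooks = [make_logging_hook("stream"), make_logging_hook("nvtx")]
print(run_op("submod_0 idx=0", hooks, sum, [1, 2, 3]))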
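
The op_hook installed by nano_ubatch_split temporarily narrows the forward context's attention metadata to the current nano-batch and restores the full per-nano-batch list afterwards. A sketch with a stand-in context object; ToyForwardContext is not a vLLM type:

from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class ToyForwardContext:
    attn_metadata: object

@contextmanager
def per_nano_batch_metadata(ctx: ToyForwardContext, idx: int):
    full = ctx.attn_metadata        # list with one entry per nano-batch
    ctx.attn_metadata = full[idx]
    try:
        yield
    finally:
        ctx.attn_metadata = full

ctx = ToyForwardContext(attn_metadata=["meta_nano0", "meta_nano1"])
with per_nano_batch_metadata(ctx, 1):
    assert ctx.attn_metadata == "meta_nano1"
assert ctx.attn_metadata == ["meta_nano0", "meta_nano1"]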