Changes from all commits (31 commits)
a8a9f38  feat: basic support of nano split. (Conless, Jul 22, 2025)
ad5fbcd  finish compute comm overlap (Conless, Jul 24, 2025)
8712a26  update (Conless, Jul 28, 2025)
c240236  fix cpu overhead (Conless, Aug 1, 2025)
49269f2  update model runner (Conless, Aug 1, 2025)
39b878b  refine interface (Conless, Aug 3, 2025)
4970a80  refine (Conless, Aug 3, 2025)
d00c4af  separate nanoflow logic (Conless, Aug 4, 2025)
9930604  implement auto-on/off logic and fix attn metadata split (Conless, Aug 7, 2025)
6517920  update (Conless, Aug 25, 2025)
d638a0a  clean the impl (Conless, Aug 25, 2025)
7bf426a  fix config (Conless, Aug 25, 2025)
a575716  update min split tokens (Conless, Aug 25, 2025)
6453b8c  format (Conless, Aug 25, 2025)
8f47d70  Merge remote-tracking branch 'upstream/main' into dev (Conless, Aug 25, 2025)
483a727  make mypy happy (Conless, Aug 27, 2025)
21ec47b  Merge remote-tracking branch 'upstream/main' into dev (Conless, Sep 1, 2025)
847c6f5  Merge remote-tracking branch 'upstream/main' into dev (Conless, Sep 1, 2025)
335aab1  move to compilation config (Conless, Sep 3, 2025)
f3b2dbe  Merge remote-tracking branch 'upstream/main' into dev (Conless, Sep 3, 2025)
563fe72  minor (Conless, Sep 21, 2025)
92903d9  Merge remote-tracking branch 'upstream/main' into dev (Conless, Sep 22, 2025)
1a26d31  Merge remote-tracking branch 'upstream/main' into dev (Conless, Sep 29, 2025)
ba26309  adapt to dbo design (Conless, Oct 2, 2025)
d1e134e  Merge remote-tracking branch 'upstream/main' into dev (Conless, Oct 2, 2025)
f087b99  fix (Conless, Oct 3, 2025)
28288f7  Merge remote-tracking branch 'upstream/main' into dev (Conless, Oct 10, 2025)
eed29c1  fix num tokens across dp (Conless, Oct 10, 2025)
d8b3573  format fix (Conless, Oct 10, 2025)
280cffb  Merge remote-tracking branch 'upstream/main' into HEAD (Conless, Nov 11, 2025)
fc3d688  fix splitting ops (Conless, Nov 11, 2025)
12 changes: 9 additions & 3 deletions vllm/compilation/backends.py
@@ -17,6 +17,7 @@
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
+from vllm.compilation.nanoflow import manager as nano_manager
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
     should_split,
@@ -694,9 +695,14 @@
             self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
             or not self.compilation_config.cudagraph_copy_inputs
         ):
-            return VllmSerializableFunction(
-                graph, example_inputs, self.prefix, self.split_gm
-            )
+            if self.compilation_config.enable_nano_batch_split:
+                return nano_manager.get_callable(
+                    self.split_gm, self.compilation_config, local_cache_dir
+                )
+            else:
+                return VllmSerializableFunction(
+                    graph, example_inputs, self.prefix, self.split_gm
+                )
 
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode

Check failure on line 699 in vllm/compilation/backends.py (GitHub Actions / pre-commit): Incompatible return value type (got "Callable[..., Any]", expected "VllmSerializableFunction") [return-value]
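For orientation, a hypothetical sketch of the configuration that drives the new branch above. The field names (enable_nano_batch_split, max_num_nano_batches, min_nano_split_tokens) come from this diff and from manager.py below; the values and comments are assumptions, not defaults taken from the PR.

```python
from vllm.config import CompilationConfig

# Example values only; the field names are the ones read by this diff and manager.py.
compilation_config = CompilationConfig(
    enable_nano_batch_split=True,  # take the nano_manager.get_callable(...) branch above
    max_num_nano_batches=4,        # manager.py pre-builds split graphs for 2..4 nano-batches
    min_nano_split_tokens=1024,    # passed to get_split_config to gate when splitting kicks in
)
```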
Empty file.
260 changes: 260 additions & 0 deletions vllm/compilation/nanoflow/manager.py
@@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import copy
import os
from collections.abc import Callable

import torch
import torch.fx.graph_module

from vllm.compilation.nanoflow.split_utils import (
    FakeModule,
    NanoOpInfo,
    NanoSplitConfig,
    analyze_graph,
    get_split_config,
    split_graph,
    tag_graph,
)
from vllm.config import CompilationConfig


class NanoSplitManager:
Review comment (Collaborator):

> The integration operates at the Torch FX graph level by partitioning input batches into nano-batches and duplicating selected operations to overlap compute and communication operations.

Could you explain a bit more about what exactly "nanoflow" does? My understanding is that you split the input batch (of tokens) into smaller batches. What exactly is "duplicating selected operations to overlap compute and communication operations"?

Review comment (Collaborator):
I found an example over at https://colab.research.google.com/drive/1zpoptkLA0UiW8ZxnBnrVj2MzEpPectQL?authuser=1 , which helps. Some concrete questions:

  1. how do you decide what operations to put on which stream?
  2. what happens in the case of multiple batches? The example only has two.

Author (Conless) replied:
> how do you decide what operations to put on which stream?

The assignment of operations to streams depends on which resource each operation is bottlenecked by. For example, GEMM operations are assigned to the compute stream, decode attention to the memory stream, and all-reduce to the network stream. This is implemented by matching operator names (e.g., "vllm.all_reduce" gets mapped to the network stream).
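To make the name matching concrete, here is a minimal, hypothetical sketch of a tag-to-stream table and a pick_stream helper; both names are illustrative and not part of this PR. Note that manager.py below only tags "vllm.all_reduce" and at runtime distinguishes just a compute stream and a communication stream, so the three-way table is the conceptual NanoFlow assignment described in this reply.

```python
import torch

# Hypothetical tag -> stream-kind table. This PR itself only tags "vllm.all_reduce"
# (see tag_graph(...) in manager.py) and uses two streams, so the three-way split
# below is the conceptual assignment described above.
OP_TAG_TO_STREAM_KIND = {
    "all_reduce": "network",        # collective communication
    "gemm": "compute",              # compute-bound matmuls
    "decode_attention": "memory",   # bandwidth-bound attention
}


def pick_stream(
    op_name: str, streams: dict[str, torch.cuda.Stream]
) -> torch.cuda.Stream:
    # Simple substring matching on the operator name, mirroring the idea of
    # "vllm.all_reduce" being routed to the communication stream.
    for pattern, kind in OP_TAG_TO_STREAM_KIND.items():
        if pattern in op_name:
            return streams[kind]
    return streams["compute"]  # default: untagged ops go to the compute stream
```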

> what happens in the case of multiple batches? The example only has two.

For cases with more nano-batches, the approach is similar: the input batch is divided into nano-batches, and operators for each nano-batch are launched on streams according to their type (compute, memory, network). Sequential dependencies within each nano-batch are maintained using event synchronization to ensure the correct execution order.
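As a rough, standalone illustration of that event-based ordering, the sketch below launches each nano-batch's ops on a per-kind stream and makes every op wait on the event recorded by its predecessor in the same nano-batch. It is a simplification of the set_stream context manager further down in manager.py; the ops_per_batch and streams inputs are hypothetical.

```python
import torch


def run_nano_batches(ops_per_batch, streams):
    # ops_per_batch: one list per nano-batch of (kind, fn) pairs to run in order.
    # streams: mapping from kind ("compute", "network", ...) to torch.cuda.Stream.
    # Ops of different nano-batches may overlap on different streams; within a
    # nano-batch, each op waits on the event recorded by its predecessor.
    prev_event = [None] * len(ops_per_batch)
    num_steps = max(len(ops) for ops in ops_per_batch)
    for step in range(num_steps):
        for idx, ops in enumerate(ops_per_batch):
            if step >= len(ops):
                continue
            kind, fn = ops[step]
            with torch.cuda.stream(streams[kind]):
                if prev_event[idx] is not None:
                    prev_event[idx].wait()  # preserve order within this nano-batch
                fn()
                done = torch.cuda.Event()
                done.record()
                prev_event[idx] = done
    torch.cuda.synchronize()
```

Each nano-batch still executes its ops in order, while ops belonging to different nano-batches are free to overlap on different streams.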

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        compilation_config: CompilationConfig,
        local_cache_dir: str | None,
    ) -> None:
        self.original_graph_module = graph_module
        self.original_graph = graph_module.graph

        # Nano split preparation
        self.min_nano_split_tokens = compilation_config.min_nano_split_tokens
        self.max_num_nano_batches = compilation_config.max_num_nano_batches
        # Initialize the base graph
        tag_graph(
            self.original_graph_module,
            {
                "vllm.all_reduce": "all_reduce",
            },
        )
        self.graph_modules = {1: self.original_graph_module}

        # Runtime preparation
        self.cached_config: NanoSplitConfig | None = None
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.comp_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.hook: (
            Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] | None
        ) = None
        self.get_bs_fn = "get_batch_size"
        self.split_fn = "split_input"
        self.wrapper_fn = "op_wrapper"
        setattr(self.original_graph_module, self.get_bs_fn, None)
        setattr(self.original_graph_module, self.split_fn, None)
        setattr(self.original_graph_module, self.wrapper_fn, None)

        splittable_inputs, base_graph = analyze_graph(self.original_graph)
        for num_splits in range(2, self.max_num_nano_batches + 1):
            new_graph = copy.deepcopy(base_graph)
            split_graph(
                self.original_graph,
                out=new_graph,
                splittable_inputs=splittable_inputs,
                num_splits=num_splits,
                get_bs_fn=self.get_bs_fn,
                split_fn=self.split_fn,
                wrapper_fn=self.wrapper_fn,
            )
            new_graph_module = torch.fx.GraphModule(
                self.original_graph_module, new_graph
            )
            for name, _ in self.original_graph_module.named_modules():
                if "." in name or name == "":
                    continue
                torch.fx.graph_module._copy_attr(
                    self.original_graph_module, new_graph_module, name
                )
            self.graph_modules[num_splits] = new_graph_module
            if local_cache_dir is not None:
                graph_path = os.path.join(
                    local_cache_dir, f"nano_split_{num_splits}.py"
                )
                if not os.path.exists(graph_path):
                    src = (
                        "from __future__ import annotations\nimport torch\n"
                        + new_graph_module.print_readable(print_output=False)
                    )
                    src = src.replace("<lambda>", "GraphModule")
                    with open(graph_path, "w") as f:
                        f.write(src)

    @staticmethod
    def get_batch_size(idx: int, cached_config: NanoSplitConfig):
        return cached_config.num_tokens[idx]

    @staticmethod
    def split_input(x: torch.Tensor, idx: int, cached_config: NanoSplitConfig):
        return x[
            cached_config.split_indices[idx] : cached_config.split_indices[idx + 1]
        ]

    @staticmethod
    def op_wrapper(
        submod_name: str,
        idx: int,
        args: tuple,
        kwargs: dict,
        gm: torch.fx.GraphModule,
        hooks: list[Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]]],
    ):
        module = getattr(gm, submod_name)
        tag = getattr(module, "tag", "")
        with contextlib.ExitStack() as stack:
            for hook in hooks:
                stack.enter_context(
                    hook(NanoOpInfo(submod_name, tag, idx, args, kwargs))
                )
            output = module(*args, **kwargs)
        return output

    def get_callable(self) -> Callable:
        def _forward(*args, **kwargs):
            if self.cached_config is None or self.cached_config.num_nano_batches == 1:
                return self.original_graph_module(*args, **kwargs)

            num_nano_batches = self.cached_config.num_nano_batches
            comm_finished: list[torch.cuda.Event | None] = [
                None for _ in range(num_nano_batches)
            ]
            comp_finished: list[torch.cuda.Event | None] = [
                None for _ in range(num_nano_batches)
            ]

            @contextlib.contextmanager
            def set_stream(op_info: NanoOpInfo):
                if op_info.tag == "all_reduce":
                    torch.cuda.set_stream(self.comm_stream)
                    comm_finished[op_info.idx] = torch.cuda.Event()
                    if comp_finished[op_info.idx] is not None:
                        # NOTE(yi): this is to make mypy happy
                        comp_finished_event = comp_finished[op_info.idx]
                        assert comp_finished_event is not None
                        comp_finished_event.wait()
                        comp_finished[op_info.idx] = None
                else:
                    torch.cuda.set_stream(self.comp_stream)
                    comp_finished[op_info.idx] = torch.cuda.Event()
                    if comm_finished[op_info.idx] is not None:
                        comm_finished_event = comm_finished[op_info.idx]
                        assert comm_finished_event is not None
                        comm_finished_event.wait()
                        comm_finished[op_info.idx] = None
                try:
                    yield
                except:
                    raise
                finally:
                    if op_info.tag == "all_reduce":
                        comm_finished_event = comm_finished[op_info.idx]
                        assert comm_finished_event is not None
                        comm_finished_event.record()
                    else:
                        comp_finished_event = comp_finished[op_info.idx]
                        assert comp_finished_event is not None
                        comp_finished_event.record()

            @contextlib.contextmanager
            def nvtx_mark(op_info: NanoOpInfo):
                try:
                    with torch.cuda.nvtx.range(
                        f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}"
                    ):
                        yield
                except:
                    raise

            # Register fake modules
            assert self.hook is not None
            op_wrapper = FakeModule(
                NanoSplitManager.op_wrapper,
                gm=self.graph_modules[num_nano_batches],
                hooks=[
                    set_stream,
                    nvtx_mark,
                    self.hook,
                ],
            )
            get_batch_size = FakeModule(
                NanoSplitManager.get_batch_size,
                cached_config=self.cached_config,
            )
            split_input = FakeModule(
                NanoSplitManager.split_input,
                cached_config=self.cached_config,
            )
            setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, op_wrapper)
            setattr(
                self.graph_modules[num_nano_batches], self.get_bs_fn, get_batch_size
            )
            setattr(self.graph_modules[num_nano_batches], self.split_fn, split_input)
            output = self.graph_modules[num_nano_batches](*args, **kwargs)
            return output

        return _forward

    def prepare(
        self,
        batch_size: int,
        num_tokens: list[int],
    ) -> NanoSplitConfig:
        self.cached_config = get_split_config(
            batch_size,
            num_tokens,
            self.max_num_nano_batches,
            self.min_nano_split_tokens,
        )
        return self.cached_config

    def set_hooks(
        self, op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]]
    ):
        self.hook = op_hook


_split_manager = None


def get_callable(
    graph_module: torch.fx.GraphModule,
    compilation_config: CompilationConfig,
    local_cache_dir: str | None = None,
) -> Callable:
    global _split_manager
    if _split_manager is None:
        _split_manager = NanoSplitManager(
            graph_module, compilation_config, local_cache_dir
        )
    return _split_manager.get_callable()


def prepare_nano_split(
    batch_size: int,
    num_tokens: list[int],
) -> NanoSplitConfig:
    global _split_manager
    if _split_manager is None:
        raise ValueError("Split manager not initialized")
    return _split_manager.prepare(batch_size, num_tokens)


def set_op_hook(
    op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]],
):
    global _split_manager
    if _split_manager is None:
        raise ValueError("Split manager not initialized")
    _split_manager.set_hooks(op_hook)
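To show how these module-level entry points fit together, a hedged usage sketch follows. The real call sites are the compiler backend above and the model runner elsewhere in this PR; run_one_step and noop_hook are hypothetical names introduced only for illustration.

```python
import contextlib

from vllm.compilation.nanoflow import manager as nano_manager
from vllm.compilation.nanoflow.split_utils import NanoOpInfo


@contextlib.contextmanager
def noop_hook(op_info: NanoOpInfo):
    # Do-nothing per-op hook; the runner's real hook would do per-op setup for
    # nano-batch op_info.idx before the wrapped submodule runs.
    yield


def run_one_step(forward, batch_size: int, num_tokens: list[int], *args, **kwargs):
    # `forward` is the callable returned by nano_manager.get_callable(...) to the
    # compiler backend; batch_size / num_tokens describe the current batch.
    # In practice the hook is installed once at startup; it is shown inline here.
    nano_manager.set_op_hook(noop_hook)
    cfg = nano_manager.prepare_nano_split(batch_size, num_tokens)
    # cfg.num_nano_batches == 1 means the original, unsplit graph will run.
    return forward(*args, **kwargs)
```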