-
-
Notifications
You must be signed in to change notification settings - Fork 11.5k
[Core] Nanoflow-style Computation-Communication Overlap #23592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Conless
wants to merge
31
commits into
vllm-project:main
Choose a base branch
from
Conless:dev
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+628
−18
Open
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
a8a9f38
feat: basic support of nano split.
Conless ad5fbcd
finish compute comm overlap
Conless 8712a26
update
Conless c240236
fix cpu overhead
Conless 49269f2
update model runner
Conless 39b878b
refine interface
Conless 4970a80
refine
Conless d00c4af
separate nanoflow logic
Conless 9930604
implement auto-on/off logic and fix attn metadata split
Conless 6517920
update
Conless d638a0a
clean the impl
Conless 7bf426a
fix config
Conless a575716
update min split tokens
Conless 6453b8c
format
Conless 8f47d70
Merge remote-tracking branch 'upstream/main' into dev
Conless 483a727
make mypy happy
Conless 21ec47b
Merge remote-tracking branch 'upstream/main' into dev
Conless 847c6f5
Merge remote-tracking branch 'upstream/main' into dev
Conless 335aab1
move to compilation config
Conless f3b2dbe
Merge remote-tracking branch 'upstream/main' into dev
Conless 563fe72
minor
Conless 92903d9
Merge remote-tracking branch 'upstream/main' into dev
Conless 1a26d31
Merge remote-tracking branch 'upstream/main' into dev
Conless ba26309
adapt to dbo design
Conless d1e134e
Merge remote-tracking branch 'upstream/main' into dev
Conless f087b99
fix
Conless 28288f7
Merge remote-tracking branch 'upstream/main' into dev
Conless eed29c1
fix num tokens across dp
Conless d8b3573
format fix
Conless 280cffb
Merge remote-tracking branch 'upstream/main' into HEAD
Conless fc3d688
fix splitting ops
Conless File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,260 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import contextlib | ||
| import copy | ||
| import os | ||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
| import torch.fx.graph_module | ||
|
|
||
| from vllm.compilation.nanoflow.split_utils import ( | ||
| FakeModule, | ||
| NanoOpInfo, | ||
| NanoSplitConfig, | ||
| analyze_graph, | ||
| get_split_config, | ||
| split_graph, | ||
| tag_graph, | ||
| ) | ||
| from vllm.config import CompilationConfig | ||
|
|
||
|
|
||
| class NanoSplitManager: | ||
| def __init__( | ||
| self, | ||
| graph_module: torch.fx.GraphModule, | ||
| compilation_config: CompilationConfig, | ||
| local_cache_dir: str | None, | ||
| ) -> None: | ||
| self.original_graph_module = graph_module | ||
| self.original_graph = graph_module.graph | ||
|
|
||
| # Nano split preparation | ||
| self.min_nano_split_tokens = compilation_config.min_nano_split_tokens | ||
| self.max_num_nano_batches = compilation_config.max_num_nano_batches | ||
| # Initialize the base graph | ||
| tag_graph( | ||
| self.original_graph_module, | ||
| { | ||
| "vllm.all_reduce": "all_reduce", | ||
| }, | ||
| ) | ||
| self.graph_modules = {1: self.original_graph_module} | ||
|
|
||
| # Runtime preparation | ||
| self.cached_config: NanoSplitConfig | None = None | ||
| self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() | ||
| self.comp_stream: torch.cuda.Stream = torch.cuda.Stream() | ||
| self.hook: ( | ||
| Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] | None | ||
| ) = None | ||
| self.get_bs_fn = "get_batch_size" | ||
| self.split_fn = "split_input" | ||
| self.wrapper_fn = "op_wrapper" | ||
| setattr(self.original_graph_module, self.get_bs_fn, None) | ||
| setattr(self.original_graph_module, self.split_fn, None) | ||
| setattr(self.original_graph_module, self.wrapper_fn, None) | ||
|
|
||
| splittable_inputs, base_graph = analyze_graph(self.original_graph) | ||
| for num_splits in range(2, self.max_num_nano_batches + 1): | ||
| new_graph = copy.deepcopy(base_graph) | ||
| split_graph( | ||
| self.original_graph, | ||
| out=new_graph, | ||
| splittable_inputs=splittable_inputs, | ||
| num_splits=num_splits, | ||
| get_bs_fn=self.get_bs_fn, | ||
| split_fn=self.split_fn, | ||
| wrapper_fn=self.wrapper_fn, | ||
| ) | ||
| new_graph_module = torch.fx.GraphModule( | ||
| self.original_graph_module, new_graph | ||
| ) | ||
| for name, _ in self.original_graph_module.named_modules(): | ||
| if "." in name or name == "": | ||
| continue | ||
| torch.fx.graph_module._copy_attr( | ||
| self.original_graph_module, new_graph_module, name | ||
| ) | ||
| self.graph_modules[num_splits] = new_graph_module | ||
| if local_cache_dir is not None: | ||
| graph_path = os.path.join( | ||
| local_cache_dir, f"nano_split_{num_splits}.py" | ||
| ) | ||
| if not os.path.exists(graph_path): | ||
| src = ( | ||
| "from __future__ import annotations\nimport torch\n" | ||
| + new_graph_module.print_readable(print_output=False) | ||
| ) | ||
| src = src.replace("<lambda>", "GraphModule") | ||
| with open(graph_path, "w") as f: | ||
| f.write(src) | ||
|
|
||
| @staticmethod | ||
| def get_batch_size(idx: int, cached_config: NanoSplitConfig): | ||
| return cached_config.num_tokens[idx] | ||
|
|
||
| @staticmethod | ||
| def split_input(x: torch.Tensor, idx: int, cached_config: NanoSplitConfig): | ||
| return x[ | ||
| cached_config.split_indices[idx] : cached_config.split_indices[idx + 1] | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def op_wrapper( | ||
| submod_name: str, | ||
| idx: int, | ||
| args: tuple, | ||
| kwargs: dict, | ||
| gm: torch.fx.GraphModule, | ||
| hooks: list[Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]]], | ||
| ): | ||
| module = getattr(gm, submod_name) | ||
| tag = getattr(module, "tag", "") | ||
| with contextlib.ExitStack() as stack: | ||
| for hook in hooks: | ||
| stack.enter_context( | ||
| hook(NanoOpInfo(submod_name, tag, idx, args, kwargs)) | ||
| ) | ||
| output = module(*args, **kwargs) | ||
| return output | ||
|
|
||
| def get_callable(self) -> Callable: | ||
| def _forward(*args, **kwargs): | ||
| if self.cached_config is None or self.cached_config.num_nano_batches == 1: | ||
| return self.original_graph_module(*args, **kwargs) | ||
|
|
||
| num_nano_batches = self.cached_config.num_nano_batches | ||
| comm_finished: list[torch.cuda.Event | None] = [ | ||
| None for _ in range(num_nano_batches) | ||
| ] | ||
| comp_finished: list[torch.cuda.Event | None] = [ | ||
| None for _ in range(num_nano_batches) | ||
| ] | ||
|
|
||
| @contextlib.contextmanager | ||
| def set_stream(op_info: NanoOpInfo): | ||
| if op_info.tag == "all_reduce": | ||
| torch.cuda.set_stream(self.comm_stream) | ||
| comm_finished[op_info.idx] = torch.cuda.Event() | ||
| if comp_finished[op_info.idx] is not None: | ||
| # NOTE(yi): this is to make mypy happy | ||
| comp_finished_event = comp_finished[op_info.idx] | ||
| assert comp_finished_event is not None | ||
| comp_finished_event.wait() | ||
| comp_finished[op_info.idx] = None | ||
| else: | ||
| torch.cuda.set_stream(self.comp_stream) | ||
| comp_finished[op_info.idx] = torch.cuda.Event() | ||
| if comm_finished[op_info.idx] is not None: | ||
| comm_finished_event = comm_finished[op_info.idx] | ||
| assert comm_finished_event is not None | ||
| comm_finished_event.wait() | ||
| comm_finished[op_info.idx] = None | ||
| try: | ||
| yield | ||
| except: | ||
| raise | ||
| finally: | ||
| if op_info.tag == "all_reduce": | ||
| comm_finished_event = comm_finished[op_info.idx] | ||
| assert comm_finished_event is not None | ||
| comm_finished_event.record() | ||
| else: | ||
| comp_finished_event = comp_finished[op_info.idx] | ||
| assert comp_finished_event is not None | ||
| comp_finished_event.record() | ||
|
|
||
| @contextlib.contextmanager | ||
| def nvtx_mark(op_info: NanoOpInfo): | ||
| try: | ||
| with torch.cuda.nvtx.range( | ||
| f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" | ||
| ): | ||
| yield | ||
| except: | ||
| raise | ||
|
|
||
| # Register fake modules | ||
| assert self.hook is not None | ||
| op_wrapper = FakeModule( | ||
| NanoSplitManager.op_wrapper, | ||
| gm=self.graph_modules[num_nano_batches], | ||
| hooks=[ | ||
| set_stream, | ||
| nvtx_mark, | ||
| self.hook, | ||
| ], | ||
| ) | ||
| get_batch_size = FakeModule( | ||
| NanoSplitManager.get_batch_size, | ||
| cached_config=self.cached_config, | ||
| ) | ||
| split_input = FakeModule( | ||
| NanoSplitManager.split_input, | ||
| cached_config=self.cached_config, | ||
| ) | ||
| setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, op_wrapper) | ||
| setattr( | ||
| self.graph_modules[num_nano_batches], self.get_bs_fn, get_batch_size | ||
| ) | ||
| setattr(self.graph_modules[num_nano_batches], self.split_fn, split_input) | ||
| output = self.graph_modules[num_nano_batches](*args, **kwargs) | ||
| return output | ||
|
|
||
| return _forward | ||
|
|
||
| def prepare( | ||
| self, | ||
| batch_size: int, | ||
| num_tokens: list[int], | ||
| ) -> NanoSplitConfig: | ||
| self.cached_config = get_split_config( | ||
| batch_size, | ||
| num_tokens, | ||
| self.max_num_nano_batches, | ||
| self.min_nano_split_tokens, | ||
| ) | ||
| return self.cached_config | ||
|
|
||
| def set_hooks( | ||
| self, op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] | ||
| ): | ||
| self.hook = op_hook | ||
|
|
||
|
|
||
| _split_manager = None | ||
|
|
||
|
|
||
| def get_callable( | ||
| graph_module: torch.fx.GraphModule, | ||
| compilation_config: CompilationConfig, | ||
| local_cache_dir: str | None = None, | ||
| ) -> Callable: | ||
| global _split_manager | ||
| if _split_manager is None: | ||
| _split_manager = NanoSplitManager( | ||
| graph_module, compilation_config, local_cache_dir | ||
| ) | ||
| return _split_manager.get_callable() | ||
|
|
||
|
|
||
| def prepare_nano_split( | ||
| batch_size: int, | ||
| num_tokens: list[int], | ||
| ) -> NanoSplitConfig: | ||
| global _split_manager | ||
| if _split_manager is None: | ||
| raise ValueError("Split manager not initialized") | ||
| return _split_manager.prepare(batch_size, num_tokens) | ||
|
|
||
|
|
||
| def set_op_hook( | ||
| op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]], | ||
| ): | ||
| global _split_manager | ||
| if _split_manager is None: | ||
| raise ValueError("Split manager not initialized") | ||
| _split_manager.set_hooks(op_hook) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain a bit more on what exactly "nanoflow" does? My understanding is that you split the input batch (of tokens) into smaller batches. What exactly is "duplicating selected operations to overlap compute and communication operations." ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found an example over at https://colab.research.google.com/drive/1zpoptkLA0UiW8ZxnBnrVj2MzEpPectQL?authuser=1 , which helps. Some concrete questions:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assignment of operations to streams depends on which resource each operation is bottlenecked by. For example, GEMM operations are assigned to the compute stream, decode attention to the memory stream, and all-reduce to the network stream. This is implemented by matching operator names (e.g.,
"vllm.all_reduce"gets mapped to the network stream).For cases with more nano-batches, the approach is similar: the input batch is divided into nano-batches, and operators for each nano-batch are launched on streams according to their type (compute, memory, network). Sequential dependencies within each nano-batch are maintained using event synchronization to ensure the correct execution order.