Commit 335aab1

move to compilation config

Signed-off-by: Yi Pan <[email protected]>
1 parent 847c6f5 commit 335aab1

File tree: 6 files changed, +47 -41 lines

vllm/compilation/backends.py

Lines changed: 3 additions & 2 deletions
@@ -596,9 +596,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
                 not self.compilation_config.cudagraph_copy_inputs:
-            if self.vllm_config.model_config.enable_nano_batch_split:
+            if self.compilation_config.enable_nano_batch_split:
                 return nano_manager.get_callable(self.split_gm,
-                                                 self.vllm_config)
+                                                 self.compilation_config,
+                                                 local_cache_dir)
             else:
                 return self.split_gm

vllm/compilation/nanoflow/manager.py

Lines changed: 23 additions & 9 deletions
@@ -3,6 +3,7 @@
 
 import contextlib
 import copy
+import os
 from typing import Callable, Optional
 
 import torch
@@ -13,24 +14,23 @@
     analyze_graph,
     get_split_config,
     split_graph, tag_graph)
-from vllm.config import VllmConfig
+from vllm.config import CompilationConfig
 
 
 class NanoSplitManager:
 
     def __init__(
         self,
         graph_module: torch.fx.GraphModule,
-        vllm_config: VllmConfig,
+        compilation_config: CompilationConfig,
+        local_cache_dir: Optional[str],
     ) -> None:
         self.original_graph_module = graph_module
         self.original_graph = graph_module.graph
 
         # Nano split preparation
-        self.min_nano_split_tokens = \
-            vllm_config.model_config.min_nano_split_tokens
-        self.max_num_nano_batches = \
-            vllm_config.model_config.max_num_nano_batches
+        self.min_nano_split_tokens = compilation_config.min_nano_split_tokens
+        self.max_num_nano_batches = compilation_config.max_num_nano_batches
         # Initialize the base graph
         tag_graph(
             self.original_graph_module,
@@ -75,6 +75,16 @@ def __init__(
             torch.fx.graph_module._copy_attr(self.original_graph_module,
                                              new_graph_module, name)
             self.graph_modules[num_splits] = new_graph_module
+            if local_cache_dir is not None:
+                graph_path = os.path.join(local_cache_dir,
+                                          f"nano_split_{num_splits}.py")
+                if not os.path.exists(graph_path):
+                    src = (
+                        "from __future__ import annotations\nimport torch\n" +
+                        new_graph_module.print_readable(print_output=False))
+                    src = src.replace("<lambda>", "GraphModule")
+                    with open(graph_path, "w") as f:
+                        f.write(src)
 
     @staticmethod
     def get_batch_size(idx: int, cached_config: NanoSplitConfig):
@@ -215,11 +225,15 @@ def set_hooks(self,
 _split_manager = None
 
 
-def get_callable(graph_module: torch.fx.GraphModule,
-                 vllm_config: VllmConfig) -> Callable:
+def get_callable(
+    graph_module: torch.fx.GraphModule,
+    compilation_config: CompilationConfig,
+    local_cache_dir: Optional[str] = None,
+) -> Callable:
     global _split_manager
     if _split_manager is None:
-        _split_manager = NanoSplitManager(graph_module, vllm_config)
+        _split_manager = NanoSplitManager(graph_module, compilation_config,
+                                          local_cache_dir)
     return _split_manager.get_callable()

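For context on the new dump logic above: torch.fx's GraphModule.print_readable returns the generated module source as a string when called with print_output=False, which is what gets written to nano_split_{num_splits}.py. A minimal standalone sketch of the same pattern, assuming a toy traced module and a scratch cache directory (neither is part of this commit):

import os

import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    # Illustrative module standing in for vLLM's split graph module.
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Toy())

cache_dir = "/tmp/nano_cache"  # stand-in for local_cache_dir
os.makedirs(cache_dir, exist_ok=True)
graph_path = os.path.join(cache_dir, "nano_split_2.py")

if not os.path.exists(graph_path):
    # print_readable(print_output=False) returns the source instead of printing it.
    src = ("from __future__ import annotations\nimport torch\n" +
           gm.print_readable(print_output=False))
    # fx sometimes names the generated class "<lambda>"; rename it so the dump
    # is importable as a normal Python file.
    src = src.replace("<lambda>", "GraphModule")
    with open(graph_path, "w") as f:
        f.write(src)

Writing only when the file is absent keeps repeated compilations from rewriting identical dumps.
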
vllm/config/__init__.py

Lines changed: 9 additions & 17 deletions
@@ -503,13 +503,6 @@ class ModelConfig:
     definitions"""
     io_processor_plugin: Optional[str] = None
     """IOProcessor plugin name to load at model startup"""
-    enable_nano_batch_split: bool = False
-    """Enable splitting the input batch into nano-batches for intra-device
-    parallelism"""
-    max_num_nano_batches: int = 2
-    """Maximum number of nano-batches to split the input batch into"""
-    min_nano_split_tokens: int = 1024
-    """Minimum number of tokens to split the input batch"""
 
     def compute_hash(self) -> str:
         """
@@ -538,9 +531,6 @@ def compute_hash(self) -> str:
         factors.append(self.override_generation_config)
         factors.append(self.rope_scaling)
         factors.append(self.rope_theta)
-        factors.append(self.enable_nano_batch_split)
-        factors.append(self.max_num_nano_batches)
-        factors.append(self.min_nano_split_tokens)
         # hf_config can control how the model looks!
         factors.append(self.hf_config.to_json_string())
         str_factors = str(factors)
@@ -3603,25 +3593,27 @@ def __post_init__(self):
                 "To workaround this limitation, vLLM will set 'ieee' input "
                 "precision for chunked prefill triton kernels.")
 
-        if self.model_config.enable_nano_batch_split:
+        if self.compilation_config.enable_nano_batch_split:
             if self.model_config.enforce_eager:
                 logger.info("nano batch split is not supported with "
                             "enforce_eager. Disabling nano batch split.")
-                self.model_config.enable_nano_batch_split = False
-            elif self.compilation_config.use_cudagraph:
+                self.compilation_config.enable_nano_batch_split = False
+            elif self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                 logger.info("nano batch split is currently not supported with "
                             "cudagraph. Disabling nano batch split.")
-                self.model_config.enable_nano_batch_split = False
+                self.compilation_config.enable_nano_batch_split = False
             elif self.compilation_config.full_cuda_graph:
                 logger.info("full_cuda_graph is not supported with "
                             "nano batch split. Disabling nano batch split.")
-                self.model_config.enable_nano_batch_split = False
+                self.compilation_config.enable_nano_batch_split = False
             elif self.compilation_config.splitting_ops:
                 logger.info("splitting_ops is not supported with "
                             "nano batch split. Disabling nano batch split.")
-                self.model_config.enable_nano_batch_split = False
+                self.compilation_config.enable_nano_batch_split = False
             else:
-                self.compilation_config.splitting_ops = ["vllm.all_reduce"]
+                self.compilation_config.splitting_ops = [
+                    "vllm.all_reduce",
+                ]
         # If the user does not explicitly set a compilation level, then
         # we use the default level. The default level depends on other
         # settings (see the below code).

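The cascade above only ever turns the feature off; nano batch split is requested solely through CompilationConfig, and any conflicting setting wins. A condensed sketch of the same precedence using plain values instead of the real config objects (function and parameter names here are illustrative, not vLLM APIs):

from typing import List, Tuple


def resolve_nano_batch_split(
    enable: bool,
    enforce_eager: bool,
    cudagraph_enabled: bool,
    full_cuda_graph: bool,
    splitting_ops: List[str],
) -> Tuple[bool, List[str]]:
    # Mirrors the __post_init__ checks: enforce_eager, any cudagraph mode,
    # full_cuda_graph, or user-provided splitting_ops each disable the split;
    # otherwise splitting_ops is forced to ["vllm.all_reduce"].
    if not enable:
        return False, splitting_ops
    if enforce_eager or cudagraph_enabled or full_cuda_graph or splitting_ops:
        return False, splitting_ops
    return True, ["vllm.all_reduce"]


# Example: cudagraphs enabled, so the split is dropped.
assert resolve_nano_batch_split(True, False, True, False, []) == (False, [])
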
vllm/config/compilation.py

Lines changed: 11 additions & 0 deletions
@@ -299,6 +299,14 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
     """
 
+    enable_nano_batch_split: bool = False
+    """Enable splitting the input batch into nano-batches for intra-device
+    parallelism"""
+    max_num_nano_batches: int = 2
+    """Maximum number of nano-batches to split the input batch into"""
+    min_nano_split_tokens: int = 1024
+    """Minimum number of tokens to split the input batch"""
+
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
 
@@ -363,6 +371,9 @@ def compute_hash(self) -> str:
         factors.append(self.inductor_compile_config)
         factors.append(self.inductor_passes)
         factors.append(self.pass_config.uuid())
+        factors.append(self.enable_nano_batch_split)
+        factors.append(self.max_num_nano_batches)
+        factors.append(self.min_nano_split_tokens)
         return hashlib.sha256(str(factors).encode()).hexdigest()
 
     def __repr__(self) -> str:

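Registering the three knobs in compute_hash means two otherwise identical compilation configs that differ only in a nano-batch-split setting produce different cache keys, so compiled artifacts are not reused across them. A small sketch, assuming CompilationConfig is constructible with its defaults as declared above:

from vllm.config import CompilationConfig

base = CompilationConfig()
split = CompilationConfig(enable_nano_batch_split=True)

# The new fields are appended to the hash factors, so these should differ.
assert base.compute_hash() != split.compute_hash()
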
vllm/engine/arg_utils.py

Lines changed: 0 additions & 12 deletions
@@ -435,9 +435,6 @@ class EngineArgs:
         get_field(ModelConfig, "override_generation_config")
     model_impl: str = ModelConfig.model_impl
     override_attention_dtype: str = ModelConfig.override_attention_dtype
-    enable_nano_batch_split: bool = ModelConfig.enable_nano_batch_split
-    max_num_nano_batches: int = ModelConfig.max_num_nano_batches
-    min_nano_split_tokens: int = ModelConfig.min_nano_split_tokens
 
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
@@ -583,12 +580,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **model_kwargs["logits_processors"])
         model_group.add_argument("--io-processor-plugin",
                                  **model_kwargs["io_processor_plugin"])
-        model_group.add_argument("--enable-nano-batch-split",
-                                 **model_kwargs["enable_nano_batch_split"])
-        model_group.add_argument("--max-num-nano-batches",
-                                 **model_kwargs["max_num_nano_batches"])
-        model_group.add_argument("--min-nano-split-tokens",
-                                 **model_kwargs["min_nano_split_tokens"])
         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
         load_group = parser.add_argument_group(
@@ -1005,9 +996,6 @@ def create_model_config(self) -> ModelConfig:
             override_attention_dtype=self.override_attention_dtype,
             logits_processors=self.logits_processors,
             io_processor_plugin=self.io_processor_plugin,
-            enable_nano_batch_split=self.enable_nano_batch_split,
-            max_num_nano_batches=self.max_num_nano_batches,
-            min_nano_split_tokens=self.min_nano_split_tokens,
         )
 
     def validate_tensorizer_args(self):

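With the EngineArgs fields and the dedicated --enable-nano-batch-split / --max-num-nano-batches / --min-nano-split-tokens flags removed, the knobs would now travel through the compilation config that EngineArgs already carries. A hedged sketch of the replacement (the model name is illustrative, and passing a CompilationConfig this way is an assumption based on the existing compilation_config field, not something shown in this diff):

from vllm.config import CompilationConfig
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="meta-llama/Llama-3.1-8B",  # illustrative
    compilation_config=CompilationConfig(
        enable_nano_batch_split=True,
        max_num_nano_batches=2,
        min_nano_split_tokens=1024,
    ),
)
# The values flow into VllmConfig when the engine config is built, where the
# __post_init__ checks from vllm/config/__init__.py may still disable them.
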
vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -1584,7 +1584,7 @@ def execute_model(
                 batch_descriptor=batch_descriptor,
         ), self.maybe_get_kv_connector_output(
                 scheduler_output) as kv_connector_output:
-            if self.vllm_config.model_config.enable_nano_batch_split:
+            if self.vllm_config.compilation_config.enable_nano_batch_split:
                 self._prepare_nano_split(scheduler_output)
 
             model_output = self.model(