Commit 1424344

njhill, Ronald1995, and robertgshaw2-redhat authored and committed
[Core] Support async scheduling with uniproc executor (vllm-project#24219)
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: Ronald1995 <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent c19dcee commit 1424344

9 files changed: +103 -55 lines changed


tests/basic_correctness/test_basic_correctness.py

Lines changed: 19 additions & 5 deletions
@@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs(
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("async_scheduling", [True, False])
+@pytest.mark.parametrize("model_executor", ["uni", "mp"])
 @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
@@ -70,13 +72,21 @@ def test_models(
     backend: str,
     max_tokens: int,
     enforce_eager: bool,
+    async_scheduling: bool,
+    model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
 
     if enable_prompt_embeds and envs.is_set(
             "VLLM_USE_V1") and envs.VLLM_USE_V1:
         pytest.skip("enable_prompt_embeds is not supported in v1.")
 
+    if not envs.VLLM_USE_V1:
+        if async_scheduling:
+            pytest.skip("async_scheduling only supported in v1.")
+        if model_executor != "uni":
+            pytest.skip("only test uniproc executor for v0.")
+
     if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
         pytest.skip(
             f"{backend} does not support gemma2 with full context length.")
@@ -98,11 +108,15 @@ def test_models(
         prompt_embeds = hf_model.get_prompt_embeddings(
             example_prompts)
 
-    with VllmRunner(model,
-                    max_model_len=8192,
-                    enforce_eager=enforce_eager,
-                    enable_prompt_embeds=enable_prompt_embeds,
-                    gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(
+            model,
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            enable_prompt_embeds=enable_prompt_embeds,
+            gpu_memory_utilization=0.7,
+            async_scheduling=async_scheduling,
+            distributed_executor_backend=model_executor,
+    ) as vllm_model:
         if enable_prompt_embeds:
             vllm_outputs = vllm_model.generate_greedy(
                 prompt_embeds, max_tokens)
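For reference, a minimal standalone sketch of the configuration this test now covers. It assumes the vllm.LLM entry point forwards these engine arguments (as VllmRunner does); the model name and prompt are illustrative.

# Sketch: async scheduling on the single-process ("uni") executor.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",              # illustrative small model
    async_scheduling=True,                  # overlap scheduling with execution
    distributed_executor_backend="uni",     # combination enabled by this change
    gpu_memory_utilization=0.7,
)
params = SamplingParams(temperature=0.0, max_tokens=5)
print(llm.generate(["Hello, my name is"], params)[0].outputs[0].text)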

tests/v1/engine/test_engine_core.py

Lines changed: 4 additions & 0 deletions
@@ -257,9 +257,13 @@ def initialize_from_config(
     def execute_model(
         self,
         scheduler_output,
+        non_block=False,
     ) -> Future[ModelRunnerOutput]:
         """Make execute_model non-blocking."""
 
+        # DummyExecutor used only for testing async case.
+        assert non_block
+
         def _execute():
             output = self.collective_rpc("execute_model",
                                          args=(scheduler_output, ))

vllm/engine/arg_utils.py

Lines changed: 2 additions & 5 deletions
@@ -1296,11 +1296,8 @@ def create_engine_config(
             # Async scheduling does not work with the uniprocess backend.
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "mp"
-                logger.info("Using mp-based distributed executor backend "
-                            "for async scheduling.")
-            if self.distributed_executor_backend == "uni":
-                raise ValueError("Async scheduling is not supported with "
-                                 "uni-process backend.")
+                logger.info("Defaulting to mp-based distributed executor "
+                            "backend for async scheduling.")
             if self.pipeline_parallel_size > 1:
                 raise ValueError("Async scheduling is not supported with "
                                  "pipeline-parallel-size > 1.")

vllm/executor/uniproc_executor.py

Lines changed: 46 additions & 28 deletions
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 import os
+from concurrent.futures import Future, ThreadPoolExecutor
+from functools import cached_property
 from multiprocessing import Lock
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -17,6 +18,7 @@
                         run_method)
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.utils import get_and_update_mm_cache
+from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -31,15 +33,7 @@ def _init_executor(self) -> None:
         """
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                                rpc_rank=0)
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        local_rank = 0
-        # set local rank as the device index if specified
-        device_info = self.vllm_config.device_config.device.__str__().split(
-            ":")
-        if len(device_info) > 1:
-            local_rank = int(device_info[1])
-        rank = 0
+        distributed_init_method, rank, local_rank = self._distributed_args()
         is_driver_worker = True
         kwargs = dict(
             vllm_config=self.vllm_config,
@@ -50,21 +44,56 @@ def _init_executor(self) -> None:
         )
         self.mm_receiver_cache = worker_receiver_cache_from_config(
             self.vllm_config, MULTIMODAL_REGISTRY, Lock())
+
+        self.async_output_thread: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            self.async_output_thread = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="WorkerAsyncOutput")
+
         self.collective_rpc("init_worker", args=([kwargs], ))
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")
 
+    def _distributed_args(self) -> tuple[str, int, int]:
+        """Return (distributed_init_method, rank, local_rank)."""
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        # set local rank as the device index if specified
+        device_info = self.vllm_config.device_config.device.__str__().split(
+            ":")
+        local_rank = int(device_info[1]) if len(device_info) > 1 else 0
+        return distributed_init_method, 0, local_rank
+
+    @cached_property
+    def max_concurrent_batches(self) -> int:
+        return 2 if self.scheduler_config.async_scheduling else 1
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> List[Any]:
+                       kwargs: Optional[Dict] = None,
+                       non_block: bool = False) -> List[Any]:
         if kwargs is None:
             kwargs = {}
         if self.mm_receiver_cache is not None and method == "execute_model":
             get_and_update_mm_cache(self.mm_receiver_cache, args)
-        answer = run_method(self.driver_worker, method, args, kwargs)
-        return [answer]
+
+        if not non_block:
+            return [run_method(self.driver_worker, method, args, kwargs)]
+
+        try:
+            result = run_method(self.driver_worker, method, args, kwargs)
+            if isinstance(result, AsyncModelRunnerOutput):
+                if (async_thread := self.async_output_thread) is not None:
+                    return [async_thread.submit(result.get_output)]
+                result = result.get_output()
+            future = Future[Any]()
+            future.set_result(result)
+        except Exception as e:
+            future = Future[Any]()
            future.set_exception(e)
+        return [future]
 
     def check_health(self) -> None:
         # UniProcExecutor will always be healthy as long as
@@ -116,8 +145,9 @@ def _init_executor(self) -> None:
         assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
             ("To get deterministic execution in V1, "
             "please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
-        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rpc_rank=0)
+        super()._init_executor()
+
+    def _distributed_args(self) -> tuple[str, int, int]:
         # engines are launched in torchrun-compatible launchers
         # so we can use the env:// method.
         # required env vars:
@@ -128,19 +158,7 @@ def _init_executor(self) -> None:
         distributed_init_method = "env://"
         rank = int(os.environ["RANK"])
         local_rank = int(os.environ["LOCAL_RANK"])
-        is_driver_worker = True
-        kwargs = dict(
-            vllm_config=self.vllm_config,
-            local_rank=local_rank,
-            rank=rank,
-            distributed_init_method=distributed_init_method,
-            is_driver_worker=is_driver_worker,
-        )
-        self.mm_receiver_cache = worker_receiver_cache_from_config(
-            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
-        self.collective_rpc("init_worker", args=([kwargs], ))
-        self.collective_rpc("init_device")
-        self.collective_rpc("load_model")
+        return distributed_init_method, rank, local_rank
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """

vllm/v1/engine/core.py

Lines changed: 5 additions & 4 deletions
@@ -159,6 +159,9 @@ def __init__(self,
         self.request_block_hasher = get_request_block_hasher(
             block_size, caching_hash_fn)
 
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
@@ -331,7 +334,8 @@ def step_with_batch_queue(
         model_executed = False
         if self.scheduler.has_requests():
             scheduler_output = self.scheduler.schedule()
-            future = self.model_executor.execute_model(scheduler_output)
+            future = self.model_executor.execute_model(scheduler_output,
+                                                       non_block=True)
             batch_queue.appendleft(
                 (future, scheduler_output))  # type: ignore[arg-type]
 
@@ -534,9 +538,6 @@ def __init__(
             assert addresses.coordinator_input is not None
             logger.info("Waiting for READY message from DP Coordinator...")
 
-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
-
         # Mark the startup heap as static so that it's ignored by GC.
         # Reduces pause times of oldest generation collections.
         gc.collect()
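step_with_batch_queue is what makes the non-blocking executor call useful: it keeps a batch in flight while the scheduler prepares the next one. A stripped-down toy loop showing that idea (generic callables, not the actual EngineCore code):

# Toy pipelining loop: schedule batch N+1 while batch N's Future is in flight.
from collections import deque
from concurrent.futures import Future
from typing import Any, Callable, Optional


def pipelined_steps(schedule: Callable[[], Optional[Any]],
                    execute_non_blocking: Callable[[Any], Future],
                    process_output: Callable[[Any, Any], None],
                    max_in_flight: int = 2) -> None:
    batch_queue: deque[tuple[Future, Any]] = deque()
    while True:
        scheduled = False
        if len(batch_queue) < max_in_flight:
            batch = schedule()              # returns None when there is nothing to run
            if batch is not None:
                batch_queue.appendleft((execute_non_blocking(batch), batch))
                scheduled = True
        if not batch_queue:
            break                           # nothing scheduled and nothing in flight
        if not scheduled or len(batch_queue) == max_in_flight:
            # Only block when the pipeline is full or no new work was scheduled.
            future, batch = batch_queue.pop()
            process_output(batch, future.result())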

vllm/v1/engine/core_client.py

Lines changed: 2 additions & 2 deletions
@@ -245,8 +245,8 @@ def __init__(self, *args, **kwargs):
         self.engine_core = EngineCore(*args, **kwargs)
 
     def get_output(self) -> EngineCoreOutputs:
-        outputs, _ = self.engine_core.step()
-        return outputs.get(0) or EngineCoreOutputs()
+        outputs, _ = self.engine_core.step_fn()
+        return outputs and outputs.get(0) or EngineCoreOutputs()
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()

vllm/v1/executor/abstract.py

Lines changed: 14 additions & 3 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from concurrent.futures import Future
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@
 from vllm.executor.uniproc_executor import (  # noqa
     UniProcExecutor as UniProcExecutorV0)
 from vllm.utils import resolve_obj_by_qualname
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
@@ -86,12 +87,22 @@ def determine_available_memory(self) -> list[int]:  # in bytes
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
         return self.collective_rpc("get_kv_cache_spec")
 
+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict] = None,
+                       non_block: bool = False) -> list[Any]:
+        raise NotImplementedError
+
     def execute_model(
         self,
-        scheduler_output,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         output = self.collective_rpc("execute_model",
-                                     args=(scheduler_output, ))
+                                     args=(scheduler_output, ),
+                                     non_block=non_block)
         return output[0]
 
     def execute_dummy_batch(self) -> None:
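Because execute_model may now return either a concrete ModelRunnerOutput or a Future, call sites that pass non_block=True unwrap the result themselves; a small hedged sketch of a helper that normalizes both shapes (not part of this change):

# Sketch: normalize the blocking and non-blocking return shapes of execute_model.
from concurrent.futures import Future
from typing import TypeVar, Union

T = TypeVar("T")


def resolve(output: Union[T, "Future[T]"]) -> T:
    """Return the value whether it was returned directly or as a Future."""
    return output.result() if isinstance(output, Future) else output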

vllm/v1/executor/multiproc_executor.py

Lines changed: 7 additions & 5 deletions
@@ -11,7 +11,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
+from functools import cached_property, partial
 from multiprocessing.connection import Connection
 from multiprocessing.process import BaseProcess
 from multiprocessing.synchronize import Lock as LockType
@@ -37,6 +37,7 @@
 from vllm.utils import (decorate_logs, get_distributed_init_method,
                         get_loopback_ip, get_mp_context, get_open_port,
                         set_process_title)
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.executor.abstract import Executor, FailureCallback
 from vllm.v1.executor.utils import get_and_update_mm_cache
 from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
@@ -174,9 +175,9 @@ def register_failure_callback(self, callback: FailureCallback):
 
     def execute_model(
         self,
-        scheduler_output,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        non_block = self.max_concurrent_batches > 1
 
         if not self.has_connector:
             # get output only from a single worker (output_rank)
@@ -328,7 +329,7 @@ def check_health(self) -> None:
             self.collective_rpc("check_health", timeout=10)
             return
 
-    @property
+    @cached_property
     def max_concurrent_batches(self) -> int:
         if self.scheduler_config.async_scheduling:
             return 2
@@ -632,7 +633,8 @@ def enqueue_output(self, output: Any):
             result = (WorkerProc.ResponseStatus.FAILURE, str(output))
         else:
             result = (WorkerProc.ResponseStatus.SUCCESS, output)
-        self.worker_response_mq.enqueue(result)
+        if (response_mq := self.worker_response_mq) is not None:
+            response_mq.enqueue(result)
 
     def handle_output(self, output: Any):
         """Handles output from the worker. If async scheduling is enabled,

vllm/v1/executor/ray_distributed_executor.py

Lines changed: 4 additions & 3 deletions
@@ -66,11 +66,13 @@ def max_concurrent_batches(self) -> int:
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         """Execute the model on the Ray workers.
 
         Args:
             scheduler_output: The scheduler output to execute.
+            non_block: If True, the method will return a Future.
 
         Returns:
             The model runner output.
@@ -84,15 +86,15 @@ def execute_model(
         if not self.has_connector:
             # Get output only from a single worker (output_rank)
             # When PP is not used, we block here until the result is available.
-            if self.max_concurrent_batches == 1:
+            if not non_block:
                 return refs[0].get()
 
             # When PP is used, we return a FutureWrapper immediately so that
             # the scheduler can yield to the next batch.
             return FutureWrapper(refs)
 
         # Get output from all workers when connector is present
-        if self.max_concurrent_batches == 1:
+        if not non_block:
             # Block and get results from all workers
             outputs = [ref.get() for ref in refs]
             return self.kv_output_aggregator.aggregate(outputs)
@@ -106,4 +108,3 @@ def reinitialize_distributed(
         if reconfig_request.new_data_parallel_rank == \
             ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
             self.shutdown()
-        return
