24 changes: 19 additions & 5 deletions tests/basic_correctness/test_basic_correctness.py
@@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models(
monkeypatch: pytest.MonkeyPatch,
@@ -70,13 +72,21 @@ def test_models(
backend: str,
max_tokens: int,
enforce_eager: bool,
async_scheduling: bool,
model_executor: str,
enable_prompt_embeds: bool,
) -> None:

if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")

if not envs.VLLM_USE_V1:
if async_scheduling:
pytest.skip("async_scheduling only supported in v1.")
if model_executor != "uni":
pytest.skip("only test uniproc executor for v0.")

if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip(
f"{backend} does not support gemma2 with full context length.")
@@ -98,11 +108,15 @@
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)

with VllmRunner(model,
max_model_len=8192,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model:
with VllmRunner(
model,
max_model_len=8192,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
async_scheduling=async_scheduling,
distributed_executor_backend=model_executor,
) as vllm_model:
if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
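Note: the test above now covers the async-scheduling and executor-backend matrix on V1 and skips the combinations V0 does not support. A minimal sketch for running just this test locally, assuming a CUDA machine with vLLM's test requirements installed (the environment assumptions are mine, the path and test name come from the diff):

```python
# Run the expanded test_models matrix via pytest's Python API.
import pytest

raise SystemExit(pytest.main([
    "tests/basic_correctness/test_basic_correctness.py::test_models",
    "-q",
]))
```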
4 changes: 4 additions & 0 deletions tests/v1/engine/test_engine_core.py
@@ -257,9 +257,13 @@ def initialize_from_config(
def execute_model(
self,
scheduler_output,
non_block=False,
) -> Future[ModelRunnerOutput]:
"""Make execute_model non-blocking."""

# The DummyExecutor is only used on the async scheduling path,
# so callers should always request a non-blocking result.
assert non_block

def _execute():
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
7 changes: 2 additions & 5 deletions vllm/engine/arg_utils.py
@@ -1296,11 +1296,8 @@ def create_engine_config(
# Async scheduling does not work with the uniprocess backend.
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "mp"
logger.info("Using mp-based distributed executor backend "
"for async scheduling.")
if self.distributed_executor_backend == "uni":
raise ValueError("Async scheduling is not supported with "
"uni-process backend.")
logger.info("Defaulting to mp-based distributed executor "
"backend for async scheduling.")
if self.pipeline_parallel_size > 1:
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")
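With this change, async scheduling no longer rejects the uni-process executor: an unset backend still defaults to "mp", but an explicit "uni" is allowed through. A hedged usage sketch; the constructor arguments mirror the test changes above, and the model name is only an example:

```python
# Sketch only: async scheduling combined with the uni-process executor,
# which previously raised a ValueError in create_engine_config().
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-0.5B-Instruct",   # example model, not from the PR
    async_scheduling=True,
    distributed_executor_backend="uni",
    gpu_memory_utilization=0.7,
)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```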
74 changes: 46 additions & 28 deletions vllm/executor/uniproc_executor.py
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
from multiprocessing import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -17,6 +18,7 @@
run_method)
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -31,15 +33,7 @@ def _init_executor(self) -> None:
"""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":")
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0
distributed_init_method, rank, local_rank = self._distributed_args()
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
@@ -50,21 +44,56 @@
)
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock())

self.async_output_thread: Optional[ThreadPoolExecutor] = None
if self.max_concurrent_batches > 1:
self.async_output_thread = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="WorkerAsyncOutput")

self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")

def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":")
local_rank = int(device_info[1]) if len(device_info) > 1 else 0
return distributed_init_method, 0, local_rank

@cached_property
def max_concurrent_batches(self) -> int:
return 2 if self.scheduler_config.async_scheduling else 1

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
kwargs: Optional[Dict] = None,
non_block: bool = False) -> List[Any]:
if kwargs is None:
kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
get_and_update_mm_cache(self.mm_receiver_cache, args)
answer = run_method(self.driver_worker, method, args, kwargs)
return [answer]

if not non_block:
return [run_method(self.driver_worker, method, args, kwargs)]

try:
result = run_method(self.driver_worker, method, args, kwargs)
if isinstance(result, AsyncModelRunnerOutput):
if (async_thread := self.async_output_thread) is not None:
return [async_thread.submit(result.get_output)]
result = result.get_output()
future = Future[Any]()
future.set_result(result)
except Exception as e:
future = Future[Any]()
future.set_exception(e)
return [future]

def check_health(self) -> None:
# UniProcExecutor will always be healthy as long as
@@ -116,8 +145,9 @@ def _init_executor(self) -> None:
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
("To get deterministic execution in V1, "
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
super()._init_executor()

def _distributed_args(self) -> tuple[str, int, int]:
# engines are launched in torchrun-compatible launchers
# so we can use the env:// method.
# required env vars:
@@ -128,19 +158,7 @@ def _init_executor(self) -> None:
distributed_init_method = "env://"
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
return distributed_init_method, rank, local_rank

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
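The core of the uniproc change is the non-blocking collective_rpc path: a plain result is wrapped in an already-completed Future, while an AsyncModelRunnerOutput has its get_output() resolved on the single-worker thread created in _init_executor. A standalone sketch of that pattern; the class names below are illustrative stand-ins, not vLLM types:

```python
from concurrent.futures import Future, ThreadPoolExecutor


class FakeAsyncOutput:
    """Stand-in for AsyncModelRunnerOutput: the real output is fetched lazily."""

    def get_output(self) -> str:
        return "model-runner-output"


output_thread = ThreadPoolExecutor(max_workers=1,
                                   thread_name_prefix="WorkerAsyncOutput")


def as_future(result) -> Future:
    if isinstance(result, FakeAsyncOutput):
        # Resolve off the caller's thread, mirroring async_output_thread.submit().
        return output_thread.submit(result.get_output)
    ready: Future = Future()
    ready.set_result(result)  # already computed; hand back a completed Future
    return ready


print(as_future(FakeAsyncOutput()).result())  # -> "model-runner-output"
print(as_future("sync-output").result())      # -> "sync-output"
```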
9 changes: 5 additions & 4 deletions vllm/v1/engine/core.py
@@ -159,6 +159,9 @@ def __init__(self,
self.request_block_hasher = get_request_block_hasher(
block_size, caching_hash_fn)

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)

def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
@@ -331,7 +334,8 @@ def step_with_batch_queue(
model_executed = False
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output)
future = self.model_executor.execute_model(scheduler_output,
non_block=True)
batch_queue.appendleft(
(future, scheduler_output)) # type: ignore[arg-type]

@@ -534,9 +538,6 @@ def __init__(
assert addresses.coordinator_input is not None
logger.info("Waiting for READY message from DP Coordinator...")

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)

# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc.collect()
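Moving step_fn into EngineCore.__init__ lets the in-process client below pick the batch-queue path, and step_with_batch_queue now always asks the executor for a Future via non_block=True, so the next batch can be scheduled while the previous one runs. A toy illustration of that pipelining, with no vLLM imports and arbitrary timings:

```python
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
import time

pool = ThreadPoolExecutor(max_workers=1)          # stands in for the model executor
batch_queue: deque[tuple[Future, int]] = deque()
MAX_CONCURRENT_BATCHES = 2


def fake_execute(batch_id: int) -> int:
    time.sleep(0.01)                              # pretend to run the model
    return batch_id


for batch_id in range(4):
    # Schedule and submit without blocking (the non_block=True path).
    batch_queue.appendleft((pool.submit(fake_execute, batch_id), batch_id))
    if len(batch_queue) >= MAX_CONCURRENT_BATCHES:
        future, _ = batch_queue.pop()             # oldest in-flight batch
        print("finished batch", future.result())

while batch_queue:
    future, _ = batch_queue.pop()
    print("finished batch", future.result())
```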
4 changes: 2 additions & 2 deletions vllm/v1/engine/core_client.py
@@ -245,8 +245,8 @@ def __init__(self, *args, **kwargs):
self.engine_core = EngineCore(*args, **kwargs)

def get_output(self) -> EngineCoreOutputs:
outputs, _ = self.engine_core.step()
return outputs.get(0) or EngineCoreOutputs()
outputs, _ = self.engine_core.step_fn()
return outputs and outputs.get(0) or EngineCoreOutputs()

def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
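The updated get_output has to tolerate the step function returning no outputs at all (presumably when nothing was scheduled), hence the extra `outputs and ...` guard. Written out more explicitly, the expression is equivalent to the sketch below; the None-returning case is an assumption inferred from the added guard:

```python
# More explicit spelling of `outputs and outputs.get(0) or EngineCoreOutputs()`.
from vllm.v1.engine import EngineCoreOutputs


def first_client_outputs(outputs) -> EngineCoreOutputs:
    if not outputs:                   # None or empty dict
        return EngineCoreOutputs()
    return outputs.get(0) or EngineCoreOutputs()
```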
17 changes: 14 additions & 3 deletions vllm/v1/executor/abstract.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from concurrent.futures import Future
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist
@@ -14,6 +14,7 @@
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

@@ -86,12 +87,22 @@ def determine_available_memory(self) -> list[int]: # in bytes
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
return self.collective_rpc("get_kv_cache_spec")

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False) -> list[Any]:
raise NotImplementedError

def execute_model(
self,
scheduler_output,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
args=(scheduler_output, ),
non_block=non_block)
return output[0]

def execute_dummy_batch(self) -> None:
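Declaring collective_rpc on the abstract Executor with a non_block flag is what lets the default execute_model stay a thin wrapper: each backend only has to honor non_block in its RPC layer. A minimal toy subclass showing how the two methods compose; this is illustrative only, not a real vLLM executor:

```python
from concurrent.futures import Future
from typing import Any, Callable, Optional, Union


class ToyExecutor:
    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False) -> list[Any]:
        result = f"ran {method}{args}"
        if not non_block:
            return [result]
        future: Future = Future()
        future.set_result(result)      # hand back a completed Future
        return [future]

    def execute_model(self, scheduler_output, non_block: bool = False):
        # Same shape as the abstract default above: forward non_block and
        # unwrap the single-element result list.
        return self.collective_rpc("execute_model",
                                   args=(scheduler_output, ),
                                   non_block=non_block)[0]


executor = ToyExecutor()
print(executor.execute_model("batch-0"))                           # plain result
print(executor.execute_model("batch-1", non_block=True).result())  # Future
```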
12 changes: 7 additions & 5 deletions vllm/v1/executor/multiproc_executor.py
@@ -11,7 +11,7 @@
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from functools import cached_property, partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from multiprocessing.synchronize import Lock as LockType
@@ -37,6 +37,7 @@
from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
@@ -174,9 +175,9 @@ def register_failure_callback(self, callback: FailureCallback):

def execute_model(
self,
scheduler_output,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
non_block = self.max_concurrent_batches > 1

if not self.has_connector:
# get output only from a single worker (output_rank)
@@ -328,7 +329,7 @@ def check_health(self) -> None:
self.collective_rpc("check_health", timeout=10)
return

@property
@cached_property
def max_concurrent_batches(self) -> int:
if self.scheduler_config.async_scheduling:
return 2
@@ -632,7 +633,8 @@ def enqueue_output(self, output: Any):
result = (WorkerProc.ResponseStatus.FAILURE, str(output))
else:
result = (WorkerProc.ResponseStatus.SUCCESS, output)
self.worker_response_mq.enqueue(result)
if (response_mq := self.worker_response_mq) is not None:
response_mq.enqueue(result)

def handle_output(self, output: Any):
"""Handles output from the worker. If async scheduling is enabled,
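On the multiproc side, execute_model now trusts the caller-supplied non_block instead of re-deriving it from max_concurrent_batches, and enqueue_output tolerates the response queue having already been torn down (the None case is presumably a shutdown race). A small sketch of that guard using the walrus operator; queue.Queue stands in for the real shared-memory message queue:

```python
import queue
from typing import Any, Optional


class ToyWorker:
    def __init__(self) -> None:
        self.worker_response_mq: Optional[queue.Queue] = queue.Queue()

    def enqueue_output(self, output: Any) -> None:
        # Drop the result silently if the queue is already gone.
        if (response_mq := self.worker_response_mq) is not None:
            response_mq.put(output)

    def shutdown(self) -> None:
        self.worker_response_mq = None


worker = ToyWorker()
worker.enqueue_output("ok")        # delivered
worker.shutdown()
worker.enqueue_output("late")      # ignored instead of raising AttributeError
```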
7 changes: 4 additions & 3 deletions vllm/v1/executor/ray_distributed_executor.py
@@ -66,11 +66,13 @@ def max_concurrent_batches(self) -> int:
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
"""Execute the model on the Ray workers.

Args:
scheduler_output: The scheduler output to execute.
non_block: If True, the method will return a Future.

Returns:
The model runner output.
@@ -84,15 +86,15 @@ def execute_model(
if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
if self.max_concurrent_batches == 1:
if not non_block:
return refs[0].get()

# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs)

# Get output from all workers when connector is present
if self.max_concurrent_batches == 1:
if not non_block:
# Block and get results from all workers
outputs = [ref.get() for ref in refs]
return self.kv_output_aggregator.aggregate(outputs)
@@ -106,4 +108,3 @@ def reinitialize_distributed(
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
self.shutdown()
return