11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
43import os
4+ from concurrent .futures import Future , ThreadPoolExecutor
5+ from functools import cached_property
56from multiprocessing import Lock
67from typing import Any , Callable , Dict , List , Optional , Tuple , Union
78
1718 run_method )
1819from vllm .v1 .engine import ReconfigureDistributedRequest , ReconfigureRankType
1920from vllm .v1 .executor .utils import get_and_update_mm_cache
21+ from vllm .v1 .outputs import AsyncModelRunnerOutput
2022from vllm .worker .worker_base import WorkerWrapperBase
2123
2224logger = init_logger (__name__ )
@@ -31,15 +33,7 @@ def _init_executor(self) -> None:
3133 """
3234 self .driver_worker = WorkerWrapperBase (vllm_config = self .vllm_config ,
3335 rpc_rank = 0 )
34- distributed_init_method = get_distributed_init_method (
35- get_ip (), get_open_port ())
36- local_rank = 0
37- # set local rank as the device index if specified
38- device_info = self .vllm_config .device_config .device .__str__ ().split (
39- ":" )
40- if len (device_info ) > 1 :
41- local_rank = int (device_info [1 ])
42- rank = 0
36+ distributed_init_method , rank , local_rank = self ._distributed_args ()
4337 is_driver_worker = True
4438 kwargs = dict (
4539 vllm_config = self .vllm_config ,
@@ -50,21 +44,56 @@ def _init_executor(self) -> None:
5044 )
5145 self .mm_receiver_cache = worker_receiver_cache_from_config (
5246 self .vllm_config , MULTIMODAL_REGISTRY , Lock ())
47+
48+ self .async_output_thread : Optional [ThreadPoolExecutor ] = None
49+ if self .max_concurrent_batches > 1 :
50+ self .async_output_thread = ThreadPoolExecutor (
51+ max_workers = 1 , thread_name_prefix = "WorkerAsyncOutput" )
52+
5353 self .collective_rpc ("init_worker" , args = ([kwargs ], ))
5454 self .collective_rpc ("init_device" )
5555 self .collective_rpc ("load_model" )
5656
def _distributed_args(self) -> tuple[str, int, int]:
    """Return ``(distributed_init_method, rank, local_rank)``.

    A single-process executor always runs as rank 0; the init method is
    a fresh TCP endpoint on this host.
    """
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    # Use the device index as the local rank when one is specified,
    # e.g. "cuda:1" -> local_rank 1; a bare "cuda"/"cpu" -> 0.
    # str(...) instead of the direct __str__() dunder call.
    device_info = str(self.vllm_config.device_config.device).split(":")
    local_rank = int(device_info[1]) if len(device_info) > 1 else 0
    return distributed_init_method, 0, local_rank
66+
@cached_property
def max_concurrent_batches(self) -> int:
    """Number of batches that may be in flight at once.

    Async scheduling pipelines two batches back-to-back; otherwise the
    executor processes one batch at a time.
    """
    if self.scheduler_config.async_scheduling:
        return 2
    return 1
70+
def collective_rpc(self,
                   method: Union[str, Callable],
                   timeout: Optional[float] = None,
                   args: Tuple = (),
                   kwargs: Optional[Dict] = None,
                   non_block: bool = False) -> List[Any]:
    """Run ``method`` on the single local worker.

    Returns a one-element list: the raw result when blocking, or a
    ``Future`` when ``non_block`` is True. An ``AsyncModelRunnerOutput``
    result is resolved on the dedicated output thread when one exists.
    """
    if kwargs is None:
        kwargs = {}
    # execute_model requests may carry multimodal inputs that must be
    # refreshed from the receiver-side cache before dispatch.
    if self.mm_receiver_cache is not None and method == "execute_model":
        get_and_update_mm_cache(self.mm_receiver_cache, args)

    if not non_block:
        # Blocking path: run inline and wrap the single worker's result.
        return [run_method(self.driver_worker, method, args, kwargs)]

    future: Future[Any] = Future()
    try:
        output = run_method(self.driver_worker, method, args, kwargs)
        if isinstance(output, AsyncModelRunnerOutput):
            output_thread = self.async_output_thread
            if output_thread is not None:
                # Resolve the async output off the caller's thread.
                return [output_thread.submit(output.get_output)]
            output = output.get_output()
        future.set_result(output)
    except Exception as exc:
        future.set_exception(exc)
    return [future]
6897
6998 def check_health (self ) -> None :
7099 # UniProcExecutor will always be healthy as long as
@@ -116,8 +145,9 @@ def _init_executor(self) -> None:
116145 assert not envs .VLLM_ENABLE_V1_MULTIPROCESSING , \
117146 ("To get deterministic execution in V1, "
118147 "please set VLLM_ENABLE_V1_MULTIPROCESSING=0" )
119- self .driver_worker = WorkerWrapperBase (vllm_config = self .vllm_config ,
120- rpc_rank = 0 )
148+ super ()._init_executor ()
149+
def _distributed_args(self) -> tuple[str, int, int]:
    """Return ``(distributed_init_method, rank, local_rank)``.

    Engines are launched by a torchrun-compatible launcher, so torch
    distributed can initialize with the ``env://`` method; ``RANK`` and
    ``LOCAL_RANK`` must be present in the environment.
    """
    # KeyError here means the launcher did not export the required vars.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    return "env://", rank, local_rank
144162
145163 def determine_num_available_blocks (self ) -> Tuple [int , int ]:
146164 """
0 commit comments