Skip to content

Commit cca2fab

Browse files
committed
Include ExecutorWithExternalLauncher
Signed-off-by: Nick Hill <[email protected]>
1 parent 79ccd33 commit cca2fab

File tree

1 file changed

+15
-22
lines changed

1 file changed

+15
-22
lines changed

vllm/executor/uniproc_executor.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,7 @@ def _init_executor(self) -> None:
3232
"""
3333
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
3434
rpc_rank=0)
35-
distributed_init_method = get_distributed_init_method(
36-
get_ip(), get_open_port())
37-
local_rank = 0
38-
# set local rank as the device index if specified
39-
device_info = self.vllm_config.device_config.device.__str__().split(
40-
":")
41-
if len(device_info) > 1:
42-
local_rank = int(device_info[1])
43-
rank = 0
35+
distributed_init_method, rank, local_rank = self._distributed_args()
4436
is_driver_worker = True
4537
kwargs = dict(
4638
vllm_config=self.vllm_config,
@@ -68,6 +60,16 @@ def _init_executor(self) -> None:
6860
self.collective_rpc("init_device")
6961
self.collective_rpc("load_model")
7062

63+
def _distributed_args(self) -> tuple[str, int, int]:
64+
"""Return (distributed_init_method, rank, local_rank)."""
65+
distributed_init_method = get_distributed_init_method(
66+
get_ip(), get_open_port())
67+
# set local rank as the device index if specified
68+
device_info = self.vllm_config.device_config.device.__str__().split(
69+
":")
70+
local_rank = int(device_info[1]) if len(device_info) > 1 else 0
71+
return distributed_init_method, 0, local_rank
72+
7173
@cached_property
7274
def max_concurrent_batches(self) -> int:
7375
return 2 if self.scheduler_config.async_scheduling else 1
@@ -162,8 +164,9 @@ def _init_executor(self) -> None:
162164
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
163165
("To get deterministic execution in V1, "
164166
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
165-
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
166-
rpc_rank=0)
167+
super()._init_executor()
168+
169+
def _distributed_args(self) -> tuple[str, int, int]:
167170
# engines are launched in torchrun-compatible launchers
168171
# so we can use the env:// method.
169172
# required env vars:
@@ -174,17 +177,7 @@ def _init_executor(self) -> None:
174177
distributed_init_method = "env://"
175178
rank = int(os.environ["RANK"])
176179
local_rank = int(os.environ["LOCAL_RANK"])
177-
is_driver_worker = True
178-
kwargs = dict(
179-
vllm_config=self.vllm_config,
180-
local_rank=local_rank,
181-
rank=rank,
182-
distributed_init_method=distributed_init_method,
183-
is_driver_worker=is_driver_worker,
184-
)
185-
self.collective_rpc("init_worker", args=([kwargs], ))
186-
self.collective_rpc("init_device")
187-
self.collective_rpc("load_model")
180+
return distributed_init_method, rank, local_rank
188181

189182
def determine_num_available_blocks(self) -> Tuple[int, int]:
190183
"""

0 commit comments

Comments
 (0)