@@ -32,15 +32,7 @@ def _init_executor(self) -> None:
3232 """
3333 self .driver_worker = WorkerWrapperBase (vllm_config = self .vllm_config ,
3434 rpc_rank = 0 )
35- distributed_init_method = get_distributed_init_method (
36- get_ip (), get_open_port ())
37- local_rank = 0
38- # set local rank as the device index if specified
39- device_info = self .vllm_config .device_config .device .__str__ ().split (
40- ":" )
41- if len (device_info ) > 1 :
42- local_rank = int (device_info [1 ])
43- rank = 0
35+ distributed_init_method , rank , local_rank = self ._distributed_args ()
4436 is_driver_worker = True
4537 kwargs = dict (
4638 vllm_config = self .vllm_config ,
@@ -68,6 +60,16 @@ def _init_executor(self) -> None:
6860 self .collective_rpc ("init_device" )
6961 self .collective_rpc ("load_model" )
7062
63+ def _distributed_args(self) -> tuple[str, int, int]:
64+ """Return (distributed_init_method, rank, local_rank)."""
65+ distributed_init_method = get_distributed_init_method(
66+ get_ip(), get_open_port())
67+ # Use the device index as the local rank when the device string has a
68+ # ":<index>" suffix; without one the local rank is 0. Rank is always 0.
69+ device_info = str(self.vllm_config.device_config.device).split(":")
70+ local_rank = int(device_info[1]) if len(device_info) > 1 else 0
71+ return distributed_init_method, 0, local_rank
72+
7173 @cached_property
7274 def max_concurrent_batches (self ) -> int :
7375 return 2 if self .scheduler_config .async_scheduling else 1
@@ -162,8 +164,9 @@ def _init_executor(self) -> None:
162164 assert not envs .VLLM_ENABLE_V1_MULTIPROCESSING , \
163165 ("To get deterministic execution in V1, "
164166 "please set VLLM_ENABLE_V1_MULTIPROCESSING=0" )
165- self .driver_worker = WorkerWrapperBase (vllm_config = self .vllm_config ,
166- rpc_rank = 0 )
167+ super ()._init_executor ()
168+
169+ def _distributed_args (self ) -> tuple [str , int , int ]:
167170 # engines are launched in torchrun-compatible launchers
168171 # so we can use the env:// method.
169172 # required env vars:
@@ -174,17 +177,7 @@ def _init_executor(self) -> None:
174177 distributed_init_method = "env://"
175178 rank = int (os .environ ["RANK" ])
176179 local_rank = int (os .environ ["LOCAL_RANK" ])
177- is_driver_worker = True
178- kwargs = dict (
179- vllm_config = self .vllm_config ,
180- local_rank = local_rank ,
181- rank = rank ,
182- distributed_init_method = distributed_init_method ,
183- is_driver_worker = is_driver_worker ,
184- )
185- self .collective_rpc ("init_worker" , args = ([kwargs ], ))
186- self .collective_rpc ("init_device" )
187- self .collective_rpc ("load_model" )
180+ return distributed_init_method , rank , local_rank
188181
189182 def determine_num_available_blocks (self ) -> Tuple [int , int ]:
190183 """
0 commit comments