 import contextlib
 import json
 import os
-import time
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
@@ ... @@
 from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
 from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
+from lmdeploy.pytorch.ray import RayContext
 from lmdeploy.utils import get_logger, try_import_deeplink

 from .base import ExecutorBase
@@ ... @@
 logger = get_logger('lmdeploy')

-PG_WAIT_TIMEOUT = 1800
-

 def get_device_str():
     """Get device str."""
@@ -43,109 +41,6 @@ def get_device_str():
     return device_type


-def _wait_until_pg_ready(current_placement_group: 'PlacementGroup'):
-    """Wait until a placement group is ready.
-
-    It prints informative log messages if the placement group is not created within time.
-    """
-    # copy from vLLM
-    # Wait until PG is ready - this will block until all
-    # requested resources are available, and will timeout
-    # if they cannot be provisioned.
-    placement_group_specs = current_placement_group.bundle_specs
-
-    s = time.time()
-    pg_ready_ref = current_placement_group.ready()
-    wait_interval = 10
-    while time.time() - s < PG_WAIT_TIMEOUT:
-        ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
-        if len(ready) > 0:
-            break
-
-        # Exponential backoff for warning print.
-        wait_interval *= 2
-        logger.info(
-            'Waiting for creating a placement group of specs for '
-            '%d seconds. specs=%s. Check '
-            '`ray status` to see if you have enough resources,'
-            ' and make sure the IP addresses used by ray cluster'
-            ' are the same as VLLM_HOST_IP environment variable'
-            ' specified in each node if you are running on a multi-node.', int(time.time() - s), placement_group_specs)
-
-    try:
-        ray.get(pg_ready_ref, timeout=0)
-    except ray.exceptions.GetTimeoutError:
-        raise ValueError('Cannot provide a placement group of '
-                         f'{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See '
-                         '`ray status` to make sure the cluster has enough resources.') from None
-
-
-def _get_obj_store_memory(dp: int = 1):
-    """Get obj store memory."""
-    import psutil
-    DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = os.getenv('RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION', '0.3')
-    DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = float(DEFAULT_OBJECT_STORE_MEMORY_PROPORTION)
-    DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = os.getenv('RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES', None)
-    if DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES is None:
-        DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = 80 * (10**9)
-    else:
-        DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = int(DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES)
-    total_mem = psutil.virtual_memory().total
-    obj_store_mem = int(total_mem * DEFAULT_OBJECT_STORE_MEMORY_PROPORTION)
-    obj_store_mem = min(DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES, obj_store_mem)
-    if dp > 1:
-        obj_store_mem = obj_store_mem // min(8, dp)
-    return obj_store_mem
-
-
-def init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1):
-    """Init ray cluster."""
-    # modified from vLLM
-    if not ray.is_initialized():
-        try:
-            num_cpus = world_size
-            object_store_memory = _get_obj_store_memory(dp=dp)
-            ray.init(address=ray_address,
-                     ignore_reinit_error=True,
-                     num_cpus=num_cpus,
-                     object_store_memory=object_store_memory)
-        except ValueError as e:
-            if e.args is not None and len(e.args) >= 1 and e.args[
-                    0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.':
-                ray.init(address=ray_address, ignore_reinit_error=True)
-            else:
-                raise
-
-    device_str = get_device_str()
-
-    # Create placement group for worker processes
-    current_placement_group = ray.util.get_current_placement_group()
-    if not current_placement_group:
-        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
-        if world_size > num_devices_in_cluster:
-            logger.warning(
-                'The number of required %ss exceeds the total '
-                'number of available %ss in the placement group.', device_str, device_str)
-        # Create a new placement group
-        placement_group_specs: List[Dict[str, float]] = ([{device_str: 1.0} for _ in range(world_size)])
-
-        # gcs_addr = ray.get_runtime_context().gcs_address
-        # master_addr = gcs_addr.split(':')[0]
-        # current_ip = master_addr
-        # # This way, at least bundle is required to be created in a current
-        # # node.
-        # placement_group_specs[0][f'node:{current_ip}'] = 0.001
-
-        # By default, Ray packs resources as much as possible.
-        current_placement_group = ray.util.placement_group(placement_group_specs, strategy='PACK')
-        _wait_until_pg_ready(current_placement_group)
-
-    assert current_placement_group is not None
-    # Set the placement group in the parallel config
-    placement_group = current_placement_group
-    return placement_group
-
-
 def _get_master_addr():
     """Get master addr."""
     addr = _envs.dist_master_addr
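
The three removed helpers above (`_wait_until_pg_ready`, `_get_obj_store_memory`, `init_ray_cluster`) are absorbed by the new `RayContext` imported from `lmdeploy.pytorch.ray`. That module is not part of this diff, so the following is only a minimal sketch of the lifecycle the executor now relies on: construct once, hand out the placement group, tear down on release. Everything beyond the three names the diff does show (`RayContext`, `get_placement_group`, `shutdown`) is an assumption.

```python
# Hypothetical sketch of the RayContext wrapper this diff imports from
# lmdeploy.pytorch.ray. The real class is not shown here; this only
# illustrates folding the removed module-level helpers into one object
# that owns the Ray cluster lifecycle.
from typing import Dict, List

import ray
from ray.util.placement_group import PlacementGroup


class RayContext:

    def __init__(self, world_size: int, dp: int = 1, device_type: str = 'cuda'):
        # Assumed: connect to (or start) the Ray cluster, mirroring the
        # removed init_ray_cluster(); the real code also sizes the object
        # store the way _get_obj_store_memory() did, using dp.
        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True)
        device_str = 'GPU' if device_type == 'cuda' else device_type
        self._pg = ray.util.get_current_placement_group()
        self._owned = self._pg is None
        if self._owned:
            # One device bundle per worker, packed as tightly as possible.
            bundles: List[Dict[str, float]] = [{device_str: 1.0} for _ in range(world_size)]
            self._pg = ray.util.placement_group(bundles, strategy='PACK')
            ray.get(self._pg.ready())  # block until resources are provisioned

    def get_placement_group(self) -> PlacementGroup:
        return self._pg

    def shutdown(self):
        # Assumed: tear down only what this context created.
        if self._owned:
            ray.util.remove_placement_group(self._pg)
        ray.shutdown()
```

Folding bring-up and tear-down into one object keeps ownership explicit: the executor's `release()` no longer needs to know whether this process started Ray or merely joined an existing cluster.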
@@ -379,7 +274,8 @@ def __init__(self,
         ray_world_size = self.world_size
         if self.dp > 1:
             ray_world_size = 1
-        placement_group = init_ray_cluster(ray_world_size, dp=dist_config.dp)
+        self.ray_ctx = RayContext(ray_world_size, dp=dist_config.dp, device_type=device_type)
+        placement_group = self.ray_ctx.get_placement_group()
         self.placement_group = placement_group

         if self.dp == 1:
@@ -476,6 +372,8 @@ def sleep(self, level: int = 1):

     def wakeup(self, tags: Optional[List[str]] = None):
         """Wakeup."""
+        if tags is None or 'kv_cache' in tags:
+            self.update_configs()
         self.collective_rpc('wakeup', (tags, ))

     def get_input_processor(self):
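
The new guard refreshes cached model configs before the workers rebuild their KV cache. Since `tags=None` means "wake everything", the refresh runs for a full wakeup and for a targeted `'kv_cache'` wakeup, but not when only other state is restored. A hedged caller-side sketch; only the `'kv_cache'` tag is confirmed by this diff, the `'weights'` tag is the conventional counterpart and an assumption here:

```python
# Hypothetical sleep/wakeup cycle as seen by a caller of the executor.
executor.sleep(level=1)             # offload/release device state
executor.wakeup(tags=['weights'])   # assumed tag: restore weights, no config refresh
executor.wakeup(tags=['kv_cache'])  # update_configs() runs, then kv cache is rebuilt
executor.wakeup()                   # tags=None wakes everything and also refreshes configs
```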
@@ -537,10 +435,7 @@ def release(self):
         else:
             [ray.kill(worker) for worker in self.workers]

-        ray.util.remove_placement_group(self.placement_group)
-        logger.debug('RayExecutor placement group removed.')
-        ray.shutdown()
-        logger.debug('Ray shutdown.')
+        self.ray_ctx.shutdown()

     def _compile_dag(self):
         """Compile dag."""
@@ -653,7 +548,7 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
         runtime_env = _update_runtime_env_nsys(runtime_env)
         worker = ray.remote(
             num_cpus=0,
-            num_gpus=1.0,
+            num_gpus=0.01,
             scheduling_strategy=scheduling_strategy,
             runtime_env=runtime_env,
         )(RayWorkerWrapper).remote(**worker_kwargs)
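
Dropping `num_gpus` from `1.0` to `0.01` makes the `RayWorkerWrapper` actor a fractional GPU consumer: the placement-group bundle still pins it to a device, but the actor no longer monopolizes the whole GPU in Ray's resource ledger, presumably leaving headroom for other actors that must land on the same device. Fractional requests are pure scheduling bookkeeping (Ray does not partition GPU memory or compute), which a standalone snippet can demonstrate:

```python
# Demonstrates Ray's fractional GPU scheduling: four actors that each
# request 1% of a GPU all fit on a single (logical) device.
import ray

ray.init(num_gpus=1)


@ray.remote(num_gpus=0.01)
class LightweightWrapper:
    """Actor that reserves only 1% of a GPU in Ray's accounting."""

    def device_ids(self):
        # Ray still assigns a concrete device via CUDA_VISIBLE_DEVICES.
        return ray.get_gpu_ids()


actors = [LightweightWrapper.remote() for _ in range(4)]
print(ray.get([a.device_ids.remote() for a in actors]))  # e.g. [[0], [0], [0], [0]]
```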