Commit 06d3693

grimoire and CyCle1024 authored

Ray mp engine backend (#3790)

* add ray mp engine
* remove flag in profile throughput
* WIP
* lint
* remove
* add sleep, wakeup & update_params method for MPEngine, call update_co… (#5)
  * add sleep, wakeup & update_params method for MPEngine, call update_configs in ray_executor wakeup kvcache
  * fix lint

Co-authored-by: CyCle1024 <[email protected]>
1 parent 489bb15 commit 06d3693

File tree

12 files changed: +803 -402 lines

lmdeploy/messages.py

Lines changed: 3 additions & 0 deletions

@@ -326,6 +326,8 @@ class PytorchEngineConfig:
         migration_backend: migration backend. options: ['DLSlime'].
             Default to `MigrationBackend.DLSlime`.
         enable_mp_engine (bool): run engine in multi-process mode.
+        mp_engine_backend (str): backend of mp engine, options:
+            ['mp', 'ray']. Default to `mp`.
         model_format (str): weight quantization policy, options: ['fp8'].
         hf_overrides (Dict[str, Any]): Huggingface overrides for the model.
             It can be used to override the default config of the model,
@@ -359,6 +361,7 @@ class PytorchEngineConfig:
     enable_microbatch: bool = False
     enable_eplb: bool = False
     enable_mp_engine: bool = False
+    mp_engine_backend: str = 'mp'
     model_format: str = None
     enable_metrics: bool = False
     hf_overrides: Optional[Dict[str, Any]] = None
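
A quick usage sketch for the new option — assuming the standard `pipeline` entry point, with a placeholder model path:

from lmdeploy import pipeline, PytorchEngineConfig

# Run the PyTorch engine in a separate process and let Ray, rather than the
# default ZMQ-based 'mp' backend, host it. 'mp_engine_backend' only takes
# effect when 'enable_mp_engine' is set.
config = PytorchEngineConfig(enable_mp_engine=True, mp_engine_backend='ray')
pipe = pipeline('path/to/model', backend_config=config)  # placeholder path
print(pipe('Hello'))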

lmdeploy/pytorch/engine/base.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
+                                                   DistServeInitRequest)
+
+
+class EngineBase:
+
+    def close(self) -> None:
+        """Close mp engine."""
+        raise NotImplementedError('This method is not implemented.')
+
+    def start_loop(self) -> None:
+        """Start mp engine loop."""
+
+    def end_session(self, session_id: int):
+        """End session."""
+        raise NotImplementedError('This method is not implemented.')
+
+    def p2p_initialize(self, conn_request: DistServeInitRequest):
+        """Init rdma link."""
+        raise NotImplementedError('This method is not implemented.')
+
+    def p2p_connect(self, conn_request: DistServeConnectionRequest):
+        """rdma_connect."""
+        raise NotImplementedError('This method is not implemented.')
+
+    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
+        """Drop connection.
+
+        1. drop engine connection (zmq connection)
+        2. TODO(JimyMa) drop RDMA Connection.
+        """
+        raise NotImplementedError('This method is not implemented.')
+
+    def create_instance(self, cuda_stream_id=0):
+        """Create instance."""
+        raise NotImplementedError('This method is not implemented.')
+
+
+class EngineInstanceBase:
+
+    async def async_end(self, session_id: int):
+        """End the given session."""
+        raise NotImplementedError('This method is not implemented.')
+
+    async def async_cancel(self, session_id: int):
+        """Stop current streaming inference."""
+        raise NotImplementedError('This method is not implemented.')
+
+    async def async_stream_infer(self, *args, **kwargs):
+        """Send stream inference request."""
+        raise NotImplementedError('This method is not implemented.')
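
Both bases raise `NotImplementedError` rather than using `abc`, so subclasses opt in method by method. A hypothetical stub, purely to illustrate the contract (not part of this commit):

from lmdeploy.pytorch.engine.base import EngineBase, EngineInstanceBase


class NoopEngine(EngineBase):
    """Hypothetical subclass: override only what callers will touch."""

    def close(self) -> None:
        pass  # nothing to tear down

    def end_session(self, session_id: int):
        pass  # sessions are stateless in this stub

    def create_instance(self, cuda_stream_id=0):
        # EngineInstanceBase is instantiable since it declares no abstract methods
        return EngineInstanceBase()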

lmdeploy/pytorch/engine/engine.py

Lines changed: 9 additions & 6 deletions

@@ -23,6 +23,7 @@
 from ..messages import MessageStatus, SchedulerSequence
 from ..model_inputs import ModelInputs, VisionModelInputs
 from ..paging import Scheduler
+from .base import EngineBase
 from .engine_checker import EngineChecker
 from .executor import build_executor
 from .logits_process import SamplingInputs
@@ -308,7 +309,7 @@ def build_inputs_maker(engine: 'Engine'):
     return InputsMakerAsync(engine)


-class Engine:
+class Engine(EngineBase):
    """The inference engine of lmdeploy pytorch.

    Args:
@@ -425,11 +426,13 @@ def from_pretrained(cls,
            trust_remote_code (bool): Trust remote code
        """
        if engine_config is not None and engine_config.enable_mp_engine:
-            from .mp_engine.mp_engine import MPEngine
-            return MPEngine(model_path=pretrained_model_name_or_path,
-                            tokenizer=tokenizer,
-                            engine_config=engine_config,
-                            trust_remote_code=trust_remote_code)
+            from .mp_engine import build_mp_engine
+            backend = engine_config.mp_engine_backend
+            return build_mp_engine(backend=backend,
+                                   model_path=pretrained_model_name_or_path,
+                                   tokenizer=tokenizer,
+                                   engine_config=engine_config,
+                                   trust_remote_code=trust_remote_code)
        if len(kwargs) > 0:
            logger.debug(f'Get unexpected kwargs: {kwargs}')
        return cls(model_path=pretrained_model_name_or_path,

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 2 additions & 1 deletion

@@ -5,6 +5,7 @@
 from lmdeploy.utils import get_logger

 from ..messages import SamplingParam
+from .base import EngineInstanceBase
 from .engine import Engine
 from .request import RequestSender, RequestType, Response, ResponseType

@@ -71,7 +72,7 @@ def cancel(req_sender: RequestSender, session_id: int):
                          f'Error: {resp.type}.'))


-class EngineInstance:
+class EngineInstance(EngineInstanceBase):
    """Instance of TurboMind.

    Args:
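
With `EngineInstance` now typed against `EngineInstanceBase`, callers can be written against the base interface. A sketch of driving the streaming API — the argument names follow the pre-existing `EngineInstance.async_stream_infer` signature and are an assumption here, since the base method only declares `*args, **kwargs`:

import asyncio


async def generate(instance, token_ids):
    # async_stream_infer is an async generator yielding incremental outputs
    async for output in instance.async_stream_infer(session_id=0, input_ids=token_ids):
        pass  # consume or forward each partial output
    await instance.async_end(session_id=0)  # release the session


# asyncio.run(generate(engine.create_instance(), [1, 2, 3]))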

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 7 additions & 112 deletions

@@ -3,7 +3,6 @@
 import contextlib
 import json
 import os
-import time
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
@@ -19,6 +18,7 @@
 from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
 from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
+from lmdeploy.pytorch.ray import RayContext
 from lmdeploy.utils import get_logger, try_import_deeplink

 from .base import ExecutorBase
@@ -27,8 +27,6 @@

 logger = get_logger('lmdeploy')

-PG_WAIT_TIMEOUT = 1800
-

 def get_device_str():
     """Get device str."""
@@ -43,109 +41,6 @@ def get_device_str():
     return device_type


-def _wait_until_pg_ready(current_placement_group: 'PlacementGroup'):
-    """Wait until a placement group is ready.
-
-    It prints the informative log messages if the placement group is not created within time.
-    """
-    # copy from vLLM
-    # Wait until PG is ready - this will block until all
-    # requested resources are available, and will timeout
-    # if they cannot be provisioned.
-    placement_group_specs = current_placement_group.bundle_specs
-
-    s = time.time()
-    pg_ready_ref = current_placement_group.ready()
-    wait_interval = 10
-    while time.time() - s < PG_WAIT_TIMEOUT:
-        ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
-        if len(ready) > 0:
-            break
-
-        # Exponential backoff for warning print.
-        wait_interval *= 2
-        logger.info(
-            'Waiting for creating a placement group of specs for '
-            '%d seconds. specs=%s. Check '
-            '`ray status` to see if you have enough resources,'
-            ' and make sure the IP addresses used by ray cluster'
-            ' are the same as VLLM_HOST_IP environment variable'
-            ' specified in each node if you are running on a multi-node.', int(time.time() - s), placement_group_specs)
-
-    try:
-        ray.get(pg_ready_ref, timeout=0)
-    except ray.exceptions.GetTimeoutError:
-        raise ValueError('Cannot provide a placement group of '
-                         f'{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See '
-                         '`ray status` to make sure the cluster has enough resources.') from None
-
-
-def _get_obj_store_memory(dp: int = 1):
-    """Get obj store memory."""
-    import psutil
-    DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = os.getenv('RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION', '0.3')
-    DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = float(DEFAULT_OBJECT_STORE_MEMORY_PROPORTION)
-    DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = os.getenv('RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES', None)
-    if DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES is None:
-        DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = 80 * (10**9)
-    else:
-        DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = int(DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES)
-    total_mem = psutil.virtual_memory().total
-    obj_store_mem = int(total_mem * DEFAULT_OBJECT_STORE_MEMORY_PROPORTION)
-    obj_store_mem = min(DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES, obj_store_mem)
-    if dp > 1:
-        obj_store_mem = obj_store_mem // min(8, dp)
-    return obj_store_mem
-
-
-def init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1):
-    """Init ray cluster."""
-    # modifier from vLLM
-    if not ray.is_initialized():
-        try:
-            num_cpus = world_size
-            object_store_memory = _get_obj_store_memory(dp=dp)
-            ray.init(address=ray_address,
-                     ignore_reinit_error=True,
-                     num_cpus=num_cpus,
-                     object_store_memory=object_store_memory)
-        except ValueError as e:
-            if e.args is not None and len(e.args) >= 1 and e.args[
-                    0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.':
-                ray.init(address=ray_address, ignore_reinit_error=True)
-            else:
-                raise
-
-    device_str = get_device_str()
-
-    # Create placement group for worker processes
-    current_placement_group = ray.util.get_current_placement_group()
-    if not current_placement_group:
-        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
-        if world_size > num_devices_in_cluster:
-            logger.warning(
-                'The number of required %ss exceeds the total '
-                'number of available %ss in the placement group.', device_str, device_str)
-        # Create a new placement group
-        placement_group_specs: List[Dict[str, float]] = ([{device_str: 1.0} for _ in range(world_size)])
-
-        # gcs_addr = ray.get_runtime_context().gcs_address
-        # master_addr = gcs_addr.split(':')[0]
-        # current_ip = master_addr
-        # # This way, at least bundle is required to be created in a current
-        # # node.
-        # placement_group_specs[0][f'node:{current_ip}'] = 0.001
-
-        # By default, Ray packs resources as much as possible.
-        current_placement_group = ray.util.placement_group(placement_group_specs, strategy='PACK')
-        _wait_until_pg_ready(current_placement_group)
-
-    assert current_placement_group is not None
-    # Set the placement group in the parallel config
-    placement_group = current_placement_group
-    return placement_group
-
-
 def _get_master_addr():
     """Get master addr."""
     addr = _envs.dist_master_addr
@@ -379,7 +274,8 @@ def __init__(self,
         ray_world_size = self.world_size
         if self.dp > 1:
             ray_world_size = 1
-        placement_group = init_ray_cluster(ray_world_size, dp=dist_config.dp)
+        self.ray_ctx = RayContext(ray_world_size, dp=dist_config.dp, device_type=device_type)
+        placement_group = self.ray_ctx.get_placement_group()
         self.placement_group = placement_group

         if self.dp == 1:
@@ -476,6 +372,8 @@ def sleep(self, level: int = 1):

     def wakeup(self, tags: Optional[List[str]] = None):
         """Wakeup."""
+        if tags is None or 'kv_cache' in tags:
+            self.update_configs()
         self.collective_rpc('wakeup', (tags, ))

     def get_input_processor(self):
@@ -537,10 +435,7 @@ def release(self):
         else:
             [ray.kill(worker) for worker in self.workers]

-        ray.util.remove_placement_group(self.placement_group)
-        logger.debug('RayExecutor placement group removed.')
-        ray.shutdown()
-        logger.debug('Ray shutdown.')
+        self.ray_ctx.shutdown()

     def _compile_dag(self):
         """Compile dag."""
@@ -653,7 +548,7 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
         runtime_env = _update_runtime_env_nsys(runtime_env)
         worker = ray.remote(
             num_cpus=0,
-            num_gpus=1.0,
+            num_gpus=0.01,
             scheduling_strategy=scheduling_strategy,
             runtime_env=runtime_env,
         )(RayWorkerWrapper).remote(**worker_kwargs)
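
The cluster bootstrap deleted above (`init_ray_cluster`, `_wait_until_pg_ready`, `_get_obj_store_memory`, `PG_WAIT_TIMEOUT`) moves behind `lmdeploy.pytorch.ray.RayContext`, whose source is not part of this view. A sketch of the surface the executor now depends on, inferred only from the call sites in this diff:

class RayContext:
    """Inferred interface; the real class lives in lmdeploy/pytorch/ray.py."""

    def __init__(self, world_size: int, dp: int = 1, device_type: str = 'cuda'):
        # presumably: ray.init() if needed, then create or claim a PACK
        # placement group, as the removed init_ray_cluster() did
        self._placement_group = None

    def get_placement_group(self):
        return self._placement_group

    def shutdown(self):
        # presumably removes the placement group and shuts Ray down,
        # replacing the inline cleanup deleted from release()
        pass

Two behavioral changes ride along: `wakeup()` now calls `update_configs()` before waking the `kv_cache`, and each worker actor reserves `num_gpus=0.01` instead of a full GPU, presumably so that other actors introduced by the Ray MP engine can share the same GPU bundles.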
lmdeploy/pytorch/engine/mp_engine/__init__.py

Lines changed: 17 additions & 0 deletions

@@ -1 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.messages import PytorchEngineConfig
+
+
+def build_mp_engine(backend: str,
+                    model_path: str,
+                    tokenizer: object,
+                    engine_config: PytorchEngineConfig = None,
+                    **kwargs):
+    """Build mp engine."""
+    if backend == 'mp':
+        from .zmq_engine import ZMQMPEngine
+        return ZMQMPEngine(model_path, tokenizer, engine_config=engine_config, **kwargs)
+    elif backend == 'ray':
+        from .ray_engine import RayMPEngine
+        return RayMPEngine(model_path, tokenizer, engine_config=engine_config, **kwargs)
+    else:
+        raise ValueError(f'Unsupported backend: {backend}')
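
`Engine.from_pretrained` routes through this factory when `enable_mp_engine` is set; calling it directly looks like the following sketch (model path and tokenizer are placeholders):

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.engine.mp_engine import build_mp_engine

config = PytorchEngineConfig(enable_mp_engine=True, mp_engine_backend='ray')
# 'mp' -> ZMQMPEngine, 'ray' -> RayMPEngine; any other value raises ValueError
engine = build_mp_engine(backend=config.mp_engine_backend,
                         model_path='path/to/model',  # placeholder
                         tokenizer=None,              # placeholder; normally a Tokenizer
                         engine_config=config)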
