
Commit 905d89e

[Feature] support model weight update in ep (#3765)
* support model weight update in ep
* Update fused_moe_backend_base.py
* Update worker_process.py
* Update dynamic_weight_manager.py
1 parent 1908465 commit 905d89e

File tree (5 files changed: +46 -21 lines)

* fastdeploy/config.py
* fastdeploy/model_executor/layers/moe/ep.py
* fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
* fastdeploy/rl/dynamic_weight_manager.py
* fastdeploy/worker/worker_process.py

fastdeploy/config.py

Lines changed: 1 addition & 1 deletion
@@ -351,8 +351,8 @@ def set_tp_group(self):
             )
         )
         # same ep group id
-        # (TODO:gaoziyuan move this gid config to ep.py)
         dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+        self.ep_group = dist.new_group(range(self.expert_parallel_size))
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
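The change creates the expert-parallel communication group once, while the group ids are being set up, and stores it on the config as `ep_group` so downstream components can reuse one handle. A minimal sketch of that create-once-and-share pattern, assuming a Paddle distributed environment is already initialized (the class and method names here are illustrative, not FastDeploy's):

    import paddle.distributed as dist

    class ParallelConfigSketch:
        """Illustrative stand-in for the parallel config touched in this diff."""

        def __init__(self, expert_parallel_size: int):
            self.expert_parallel_size = expert_parallel_size
            self.ep_group = None

        def set_ep_group(self):
            # Build the EP group exactly once; consumers such as the DeepEP
            # engine and the dynamic weight manager receive this handle
            # instead of calling new_group() themselves.
            self.ep_group = dist.new_group(range(self.expert_parallel_size))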

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 10 additions & 1 deletion
@@ -78,6 +78,7 @@ def __init__(
         splitwise_role: str,
         moe_phase: MoEPhase,
         async_finish: bool = False,
+        group=None,
     ):
         """
         Initialize the DeepEP engine.
@@ -90,7 +91,9 @@
             num_experts: The number of experts.
         """
         # TODO(@wufeisheng): Support configurable EP size
-        self.group = paddle.distributed.new_group(range(ep_size))
+        if group is None:
+            group = paddle.distributed.new_group(range(ep_size))
+        self.group = group
         self.ep_size = ep_size
         self.rank_id = ep_rank
         self.hidden = hidden
@@ -277,6 +280,7 @@ def __init__(
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        ep_group=None,
     ):
         self.top_k = top_k
         self.num_experts = num_experts
@@ -289,6 +293,7 @@
             ep_rank=ep_rank,
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
+            group=ep_group,
         )

     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -367,6 +372,7 @@ def __init__(
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        ep_group=None,
         moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
@@ -379,6 +385,7 @@
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
+            ep_group=ep_group,
         )

     def dispatch(
@@ -445,6 +452,7 @@ def __init__(
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        ep_group=None,
         moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
@@ -457,6 +465,7 @@
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
+            ep_group=ep_group,
         )

     def dispatch(
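These changes thread an optional `group`/`ep_group` argument from the runner constructors down to the DeepEP engine, which now only builds its own group as a fallback. A minimal sketch of that fallback pattern, with simplified names (not the full DeepEPEngine signature):

    import paddle.distributed as dist

    class DeepEPEngineSketch:
        def __init__(self, ep_size: int, ep_rank: int, group=None):
            if group is None:
                # Previous behaviour: every engine instance built its own group.
                group = dist.new_group(range(ep_size))
            # New behaviour: reuse the shared group built by the parallel config.
            self.group = group
            self.ep_size = ep_size
            self.rank_id = ep_rank

Reusing one shared handle is what lets dynamic_weight_manager.py below shut down and restart the same EP group that the MoE runners communicate over.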

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 4 additions & 0 deletions
@@ -58,6 +58,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
+                ep_group=layer.fd_config.parallel_config.ep_group,
             )
             self.ep_decoder_runner = EPDecoderRunner(
                 layer.top_k,
@@ -68,6 +69,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
+                ep_group=layer.fd_config.parallel_config.ep_group,
             )
         else:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
@@ -82,6 +84,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
+                    ep_group=layer.fd_config.parallel_config.ep_group,
                 )
             else:
                 from .ep import EPDecoderRunner
@@ -95,6 +98,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
+                    ep_group=layer.fd_config.parallel_config.ep_group,
                 )

     def process_loaded_weights(self, layer, weights) -> None:
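All four call sites in `init_ep` now forward the same shared group via `ep_group=layer.fd_config.parallel_config.ep_group`. A condensed, hypothetical sketch of that wiring (runner signatures reduced to the arguments relevant here):

    class RunnerSketch:
        """Stand-in for EPPrefillRunner / EPDecoderRunner."""

        def __init__(self, ep_size: int, ep_rank: int, ep_group=None):
            self.ep_group = ep_group

    def init_ep_sketch(parallel_config, ep_size: int, ep_rank: int):
        # Prefill and decode runners talk over the one EP group owned by
        # the parallel config rather than each creating their own.
        shared = parallel_config.ep_group
        prefill_runner = RunnerSketch(ep_size, ep_rank, ep_group=shared)
        decoder_runner = RunnerSketch(ep_size, ep_rank, ep_group=shared)
        return prefill_runner, decoder_runner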

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 11 additions & 6 deletions
@@ -63,7 +63,9 @@ def update_parameters(self, pid: int = 0) -> None:
         paddle.device.cuda.empty_cache()

         if not self.first_load:
-            paddle.distributed.restart_process_group()
+            paddle.distributed.restart_process_group(self.parallel_config.tp_group)
+            if self.parallel_config.enable_expert_parallel:
+                paddle.distributed.restart_process_group(self.parallel_config.ep_group)

         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
@@ -110,9 +112,12 @@ def clear_parameters(self, pid: int = 0) -> None:
             param._clear_data()

         self._verify_parameters("clearance")
-        if self.nranks > 1:
-            paddle.distributed.barrier()
-            paddle.distributed.shutdown_process_group()
+        if self.parallel_config.tensor_parallel_size > 1:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
+            paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+        if self.parallel_config.enable_expert_parallel:
+            paddle.distributed.barrier(self.parallel_config.ep_group)
+            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
         self._update_shared_status(pid, -2)

     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
@@ -141,8 +146,8 @@ def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.T
     def _finalize_update(self, pid: int):
         """Finalize update process with verification."""
         self._verify_parameters("update")
-        if self.nranks > 1:
-            paddle.distributed.barrier()
+        if self.parallel_config.tensor_parallel_size > 1:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
         if not self.first_load:
             self._update_shared_status(pid, 0)
         self.first_load = False
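With the groups named explicitly, the clear/update cycle barriers on and tears down each communicator it actually uses, then restarts it before the next load. A hedged sketch of that lifecycle, using only the `paddle.distributed` calls that appear in this diff (`barrier`, `shutdown_process_group`, `restart_process_group`) and an illustrative `reload_weights` callback:

    import paddle.distributed as dist

    def clear_then_reload_sketch(parallel_config, reload_weights):
        # Tear-down half (mirrors clear_parameters): sync and shut down each
        # group in use so cached weights can be dropped safely.
        if parallel_config.tensor_parallel_size > 1:
            dist.barrier(parallel_config.tp_group)
            dist.shutdown_process_group(parallel_config.tp_group)
        if parallel_config.enable_expert_parallel:
            dist.barrier(parallel_config.ep_group)
            dist.shutdown_process_group(parallel_config.ep_group)

        # Bring-up half (mirrors update_parameters on a non-first load):
        # restart the communicators before pulling in the new weights.
        dist.restart_process_group(parallel_config.tp_group)
        if parallel_config.enable_expert_parallel:
            dist.restart_process_group(parallel_config.ep_group)
        reload_weights()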

fastdeploy/worker/worker_process.py

Lines changed: 20 additions & 13 deletions
@@ -254,27 +254,26 @@ def event_loop_normal(self) -> None:
         """
         # Currently, only support single node
         self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
-        mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
         req_ids = []
         num_running_requests = 0
-        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+
+        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
         while True:
-            if self.local_rank == 0:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != 0:
-                    self.exist_task_signal.value[0] = 2
-                else:
-                    self.exist_task_signal.value[0] = 0
-
-            if self.parallel_config.tensor_parallel_size > 1:
-                # Synchronize before updating weights
-                paddle.distributed.barrier(self.parallel_config.tp_group)
+                    self.model_weights_signal[0] = int(self.model_weights_status.value[0])
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
+                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
+            if self.fd_config.load_config.dynamic_load_weight:
+                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)

             self.insert_step = False
             req_dicts = None
+            local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
             self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

             # The first worker detects whether there are tasks in the task queue
-            if self.local_rank % mp_num_per_node == 0:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.task_queue.num_tasks() > 0:
                     # VL only support 1 batch to prefill
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
@@ -290,16 +289,24 @@
                 paddle.distributed.barrier(self.parallel_config.tp_group)

             if self.fd_config.load_config.dynamic_load_weight:
-                if self.exist_task_signal.value[0] == 2:
+                if self.parallel_config.enable_expert_parallel:
+                    paddle.distributed.barrier(self.parallel_config.ep_group)
+                else:
+                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                if self.model_weights_signal[0] != 0:
+                    logger.info(f"Rank: {self.local_rank} has updated parameters.")
                     from fastdeploy.rl.dynamic_weight_manager import (
                         DynamicWeightManager,
                     )

+                    self.model_weights_status.value[0] = self.model_weights_signal[0]
                     DynamicWeightManager.check_model_weights_status(
                         self.model_weights_status,
+                        # model_weights_signal
                         self.worker.model_runner,
-                        self.parallel_config.engine_pid,
+                        self.parallel_config.engine_worker_queue_port,
                     )
+                    self.model_weights_signal[0] = 0

             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
