
Commit 90f3209

support both eplb and microbatch simultaneously (#3591)
1 parent: 8266648

1 file changed: +13 −7 lines changed

lmdeploy/pytorch/models/deepseek_v2.py

Lines changed: 13 additions & 7 deletions
@@ -577,7 +577,11 @@ def forward(
 class MoEGate(nn.Module):
     """Deepseek Gate."""
 
-    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
+    def __init__(self,
+                 config: Any,
+                 dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 info: eplb.EPLBDispatchInfo = None):
         super().__init__()
         self.config = config
         self.top_k = config.num_experts_per_tok
@@ -602,6 +606,7 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
         self.softmax_topk = SoftmaxTopK(self.top_k)
 
         self.fake_eplb = getenv('LMDEPLOY_FAKE_EPLB', 'False').lower() == 'true'
+        self.eplb_dispatch_info = info
 
     def _compute_scores(self, logits: torch.Tensor):
         """compute scores."""
@@ -665,6 +670,9 @@ def forward(self, hidden_states: torch.Tensor):
         if not self.renormalize or self.topk_method == 'noaux_tc':
             topk_weight = topk_weight * self.routed_scaling_factor
 
+        if self.eplb_dispatch_info is not None:
+            topk_idx = eplb.topk_ids_logical_to_physical(topk_idx, self.eplb_dispatch_info)
+
         return topk_weight, topk_idx
 
 
@@ -685,18 +693,19 @@ def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: to
         self.n_group = config.n_group
         self.topk_group = config.topk_group
 
-        self.gate = MoEGate(config, dtype=dtype, device=device)
-
         dist_ctx = get_dist_manager().current_context()
         dp = dist_ctx.dp
         world_size = dist_ctx.world_size
         moe_all_reduce = dp > 1 and dist_ctx.tp > 1
         if get_dist_manager().current_context().dist_config.enable_eplb:
-            self.eplb_dispatch_info = eplb.EPLBDispatchInfo.init_new(
+            eplb_dispatch_info = eplb.EPLBDispatchInfo.init_new(
                 ep_rank=dist_ctx.ep_rank,
                 layer_idx=layer_idx,
             )
             self.num_experts = eplb.get_global_eplb_metadata().num_physical_experts()
+            self.gate = MoEGate(config, dtype=dtype, device=device, info=eplb_dispatch_info)
+        else:
+            self.gate = MoEGate(config, dtype=dtype, device=device, info=None)
         self.experts = build_fused_moe(
             self.hidden_dim,
             self.ffn_dim,
@@ -730,9 +739,6 @@ def forward(self, hidden_states: torch.Tensor):
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         topk_weights, topk_ids = self.gate(hidden_states)
-        if get_dist_manager().current_context().dist_config.enable_eplb:
-            topk_ids = eplb.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info)
-
         out_states = self.experts(
             hidden_states,
             topk_weights,
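
Note on the change: the diff moves the EPLB logical-to-physical expert-id remap out of DeepseekV2MoE.forward and into MoEGate.forward (guarded by the new `info` argument), so the remap is applied wherever the gate runs, including the microbatched execution path. The sketch below only illustrates the general idea of such a remap; the helper name, the `log2phys` table layout, and the round-robin replica choice are assumptions for illustration, not lmdeploy's actual `eplb.topk_ids_logical_to_physical` implementation.

# Illustrative sketch only, not lmdeploy's implementation. Assumes a dispatch
# table `log2phys` of shape (num_logical_experts, max_replicas), where each row
# lists the physical expert slots hosting replicas of that logical expert.
import torch


def topk_ids_logical_to_physical_sketch(topk_idx: torch.Tensor,
                                        log2phys: torch.Tensor) -> torch.Tensor:
    """Map (num_tokens, top_k) logical expert ids to physical expert ids."""
    max_replicas = log2phys.shape[1]
    # Spread routed tokens across the replicas of each logical expert,
    # here by simple round-robin over the flattened routing slots.
    flat_idx = torch.arange(topk_idx.numel(), device=topk_idx.device)
    replica = (flat_idx % max_replicas).view_as(topk_idx)
    return log2phys[topk_idx, replica]


# Usage, mirroring the guarded call the diff adds to MoEGate.forward:
#     if eplb_dispatch_info is not None:
#         topk_idx = topk_ids_logical_to_physical_sketch(topk_idx, log2phys)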
