@@ -577,7 +577,11 @@ def forward(
 class MoEGate(nn.Module):
     """Deepseek Gate."""

-    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
+    def __init__(self,
+                 config: Any,
+                 dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 info: eplb.EPLBDispatchInfo = None):
         super().__init__()
         self.config = config
         self.top_k = config.num_experts_per_tok
@@ -602,6 +606,7 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
         self.softmax_topk = SoftmaxTopK(self.top_k)

         self.fake_eplb = getenv('LMDEPLOY_FAKE_EPLB', 'False').lower() == 'true'
+        self.eplb_dispatch_info = info

     def _compute_scores(self, logits: torch.Tensor):
         """compute scores."""
@@ -665,6 +670,9 @@ def forward(self, hidden_states: torch.Tensor):
         if not self.renormalize or self.topk_method == 'noaux_tc':
            topk_weight = topk_weight * self.routed_scaling_factor

+        if self.eplb_dispatch_info is not None:
+            topk_idx = eplb.topk_ids_logical_to_physical(topk_idx, self.eplb_dispatch_info)
+
         return topk_weight, topk_idx

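The branch added above converts the gate's logical top-k expert ids into physical expert ids whenever EPLB dispatch info is attached. As a rough illustration of what such a remap can look like (the routing-table layout and helper below are assumptions for this sketch, not eplb's actual implementation):

import torch

# Hypothetical per-layer table: row i lists the physical replicas serving
# logical expert i (EPLB may replicate hot experts across ranks).
log2phy = torch.tensor([[0, 4], [1, 5], [2, 6], [3, 7]])

def remap_logical_to_physical(topk_idx: torch.Tensor) -> torch.Tensor:
    """Pick one physical replica for every routed logical expert id."""
    replica = torch.randint(0, log2phy.shape[1], topk_idx.shape)
    return log2phy[topk_idx, replica]

topk_idx = torch.tensor([[0, 3], [2, 1]])        # logical ids from the gate
physical = remap_logical_to_physical(topk_idx)   # e.g. tensor([[4, 3], [6, 5]])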
@@ -685,18 +693,19 @@ def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: to
         self.n_group = config.n_group
         self.topk_group = config.topk_group

-        self.gate = MoEGate(config, dtype=dtype, device=device)
-
         dist_ctx = get_dist_manager().current_context()
         dp = dist_ctx.dp
         world_size = dist_ctx.world_size
         moe_all_reduce = dp > 1 and dist_ctx.tp > 1
         if get_dist_manager().current_context().dist_config.enable_eplb:
-            self.eplb_dispatch_info = eplb.EPLBDispatchInfo.init_new(
+            eplb_dispatch_info = eplb.EPLBDispatchInfo.init_new(
                 ep_rank=dist_ctx.ep_rank,
                 layer_idx=layer_idx,
             )
             self.num_experts = eplb.get_global_eplb_metadata().num_physical_experts()
+            self.gate = MoEGate(config, dtype=dtype, device=device, info=eplb_dispatch_info)
+        else:
+            self.gate = MoEGate(config, dtype=dtype, device=device, info=None)
         self.experts = build_fused_moe(
             self.hidden_dim,
             self.ffn_dim,
@@ -730,9 +739,6 @@ def forward(self, hidden_states: torch.Tensor):
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         topk_weights, topk_ids = self.gate(hidden_states)
-        if get_dist_manager().current_context().dist_config.enable_eplb:
-            topk_ids = eplb.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info)
-
         out_states = self.experts(
             hidden_states,
             topk_weights,
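Net effect: the logical-to-physical remap moves out of the MoE block's forward and into MoEGate, so the forward path no longer consults dist_config and the gate hands back physical expert ids directly. A toy stand-in sketching the new pattern (names and shapes here are illustrative only, not the lmdeploy classes):

import torch

class TinyGate(torch.nn.Module):
    """Toy stand-in for the refactored MoEGate: remaps ids itself when EPLB info is set."""

    def __init__(self, num_experts: int, hidden_dim: int, top_k: int, log2phy: torch.Tensor = None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_experts, hidden_dim))
        self.top_k = top_k
        self.log2phy = log2phy  # plays the role of eplb_dispatch_info

    def forward(self, hidden_states: torch.Tensor):
        logits = hidden_states @ self.weight.t()
        topk_weight, topk_idx = logits.softmax(dim=-1).topk(self.top_k, dim=-1)
        if self.log2phy is not None:  # the remap now lives inside the gate
            topk_idx = self.log2phy[topk_idx]
        return topk_weight, topk_idx

# Callers no longer branch on enable_eplb: the gate already returns physical ids.
gate = TinyGate(num_experts=4, hidden_dim=8, top_k=2, log2phy=torch.tensor([0, 2, 4, 6]))
topk_weights, topk_ids = gate(torch.randn(3, 8))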