class MoEGateMixin:
    def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor:
+        # [..., hidden_dim] -> [..., num_experts]
        with paddle.amp.auto_cast(False):
-            # [..., hidden_dim] -> [..., num_experts]
            scoring_func = getattr(self, "scoring_func", None)
            if scoring_func == "softmax":
-                scores = F.softmax(logits, axis=-1)
+                scores = F.softmax(logits.cast("float32"), axis=-1)
            elif scoring_func == "sigmoid":
-                scores = F.sigmoid(logits)
+                scores = F.sigmoid(logits.cast("float32"))
            elif scoring_func == "tanh":
-                scores = F.tanh(logits)
+                scores = F.tanh(logits.cast("float32"))
            elif scoring_func == "relu":
-                scores = F.relu(logits)
+                scores = F.relu(logits.cast("float32"))
            elif scoring_func == "gelu":
-                scores = F.gelu(logits)
+                scores = F.gelu(logits.cast("float32"))
            elif scoring_func == "leaky_relu":
-                scores = F.leaky_relu(logits)
+                scores = F.leaky_relu(logits.cast("float32"))
            else:
                logger.warning_once(
                    f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead"
                )
-                scores = F.softmax(logits, axis=-1)
+                scores = F.softmax(logits.cast("float32"), axis=-1)
        return scores

    def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor:
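The hunk above casts the routing logits to float32 before every gate activation, so the expert scores are never computed in the model's low-precision dtype. A minimal sketch of the resulting call pattern, assuming a bfloat16 model and the softmax scoring function (shapes are illustrative):

import paddle
import paddle.nn.functional as F

# Routing logits as they would arrive from a bfloat16 model: [tokens, num_experts]
logits = paddle.randn([2, 8]).cast("bfloat16")
# Cast to float32 first, as gate_score_func now does, so the probabilities are full precision
scores = F.softmax(logits.cast("float32"), axis=-1)
print(scores.dtype)  # paddle.float32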
@@ -130,9 +130,7 @@ def _cal_aux_loss(self, gates, mask):
        aux_loss = paddle.sum(me * ce) * float(self.num_experts)
        return aux_loss

-    def _cal_seq_aux_loss(self, probs, top_k, routing_map, seq_length):
-        max_seq_len = seq_length
-
+    def _cal_seq_aux_loss(self, probs, top_k, routing_map, max_seq_len):
        sub_max_seq_len = max_seq_len
        if hasattr(self, "moe_subbatch_token_num") and self.moe_subbatch_token_num > 0:
            sub_max_seq_len = self.moe_subbatch_token_num * self.tensor_parallel_degree
@@ -162,7 +160,6 @@ def _cal_seq_aux_loss(self, probs, top_k, routing_map, seq_length):
        )
        # [B, E] -> [B] -> []
        seq_aux_loss = (cost_coeff * all_probs.sum(axis=seq_axis) / max_seq_len).sum(axis=1).mean()
-
        return seq_aux_loss

    def _cal_z_loss(self, logits) -> paddle.Tensor:
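With the signature change in _cal_seq_aux_loss above, the sequence length is passed in directly as max_seq_len, and the effective per-subbatch length is derived from it when MoE sub-batching is enabled. A toy calculation of that derivation under assumed values (all numbers are illustrative, not defaults):

# Illustrative values only
max_seq_len = 8192
moe_subbatch_token_num = 1024      # tokens per MoE sub-batch on one tensor-parallel rank
tensor_parallel_degree = 4

sub_max_seq_len = max_seq_len
if moe_subbatch_token_num > 0:
    sub_max_seq_len = moe_subbatch_token_num * tensor_parallel_degree  # 4096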
@@ -361,6 +358,9 @@ def _topk_noaux_tc(
        )  # [n, e]
        tmp_scores = scores_for_choice * score_mask  # [n, e]
        topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True)
+
+        # The bias term b is used only to adjust affinity scores for Top-K expert selection (routing); it does not affect gating.
+        # The gate applied during dispatch and to weight the FFN output is computed from the original affinity score s_{i,t} (without the bias).
        topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight

        return topk_weight, topk_idx
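The comment added in _topk_noaux_tc above describes the bias-corrected routing idea: the correction bias only steers which experts are selected, while the gate weights are gathered from the unbiased scores. A stripped-down sketch of that selection step (it omits the group masking shown above; shapes and k are illustrative):

import paddle
import paddle.nn.functional as F

scores = F.sigmoid(paddle.randn([4, 8]))               # unbiased affinities s_{i,t}: [tokens, experts]
e_score_correction_bias = paddle.zeros([8])             # bias b, adjusted elsewhere during training
scores_for_choice = scores + e_score_correction_bias    # used only to pick the experts

_, topk_idx = paddle.topk(scores_for_choice, k=2, axis=-1, sorted=True)
topk_weight = scores.take_along_axis(topk_idx, axis=1)  # gate weights from the original scores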
@@ -378,6 +378,13 @@ def __init__(
        norm_topk_prob: bool,
        moe_config: Dict,
        seq_length: int,
+        n_group: int,
+        topk_group: int,
+        routed_scaling_factor: float,
+        moe_subbatch_token_num: int,
+        tensor_parallel_degree: int,
+        sequence_parallel: bool,
+        transpose_gate_weight: bool,
    ):
        super(StandardMoEGate, self).__init__()

@@ -390,8 +397,15 @@ def __init__(
        # force keep in float32 when using amp
        self._cast_to_low_precision = False
        self.seq_length = seq_length
-
-        self.scoring_func = moe_config.get("scoring_func", "softmax")
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.routed_scaling_factor = routed_scaling_factor
+        self.moe_subbatch_token_num = moe_subbatch_token_num
+        self.tensor_parallel_degree = tensor_parallel_degree
+        self.sequence_parallel = sequence_parallel
+        self.transpose_gate_weight = transpose_gate_weight
+
+        self.scoring_func = moe_config.get("gate_activation", "softmax")
        self.capacity_factor = moe_config.get("capacity_factor", 1.0)
        self.eval_capacity_factor = moe_config.get("eval_capacity_factor", 1.0)
        self.min_capacity = moe_config.get("min_capacity", 1)
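After the rename above, the gate activation is read from the "gate_activation" key instead of "scoring_func". A hypothetical moe_config covering the keys this constructor reads (values are illustrative only, not defaults):

moe_config = {
    "gate_activation": "sigmoid",
    "capacity_factor": 1.0,
    "eval_capacity_factor": 1.0,
    "min_capacity": 1,
    "drop_policy": "probs",
    "seq_aux": True,
}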
@@ -401,26 +415,45 @@ def __init__(
        self.use_rts = moe_config.get("use_rts", True)
        self.top2_2nd_expert_sampling = moe_config.get("top2_2nd_expert_sampling", True)
        self.drop_policy = moe_config.get("drop_policy", "probs")
-        self.n_group = moe_config.get("n_group", 1)  # for group_limited_greedy
-        self.topk_group = moe_config.get("topk_group", 1)  # for group_limited_greedy
-        self.routed_scaling_factor = moe_config.get("routed_scaling_factor", 1.0)
-        self.seq_aux = moe_config.get("seq_aux", False)
+        self.seq_aux = moe_config.get("seq_aux", True)

        if self.global_aux_loss:
            assert self.group is not None, "group is required when global_aux_loss is True"
            self.rank = dist.get_rank(self.group)

-        self.weight = paddle.create_parameter(
-            shape=[self.expert_hidden_size, self.num_experts],
-            dtype="float32",
-            default_initializer=paddle.nn.initializer.Uniform(),
-        )
+        # According to the shape of the gate weights in the model checkpoint
+        if not transpose_gate_weight:
+            self.weight = paddle.create_parameter(
+                shape=[self.expert_hidden_size, self.num_experts],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Uniform(),
+            )
+        else:
+            self.weight = paddle.create_parameter(
+                shape=[self.num_experts, self.expert_hidden_size],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Uniform(),
+            )
+
+        if self.topk_method == "noaux_tc":
+            self.register_buffer("e_score_correction_bias", paddle.zeros((self.num_experts,), dtype=paddle.float32))
+            self._cast_to_low_precision = False
+            self.expert_usage = paddle.zeros(
+                shape=[self.num_experts],
+                dtype=paddle.int64,
+            )  # Used in MoECorrectionBiasAdjustCallback
+            self.expert_usage.stop_gradient = True

    def forward(
        self,
        gates: paddle.Tensor,
    ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-        return self.topkgating(gates)
+        capacity, top_gate, top_idx, gates_masked, mask, token_priority, l_aux, l_zloss = self.topkgating(gates)
+        exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
+        if self.topk_method == "noaux_tc":
+            with paddle.no_grad():
+                self.expert_usage += exp_counts
+        return capacity, top_gate, top_idx, gates_masked, mask, token_priority, l_aux, l_zloss

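The reworked forward above also tallies how many tokens each expert received, which MoECorrectionBiasAdjustCallback later uses to adjust the correction bias. A toy example of the exp_counts computation (the mask values are made up):

import paddle

# One-hot dispatch mask: [num_tokens, num_experts]
mask = paddle.to_tensor([[1, 0, 1, 0],
                         [0, 1, 1, 0],
                         [1, 0, 0, 1]])
exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
print(exp_counts.numpy())  # [2 1 2 1]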
    def topkgating(
        self,
@@ -434,14 +467,19 @@ def topkgating(
        elif len(gates.shape) == 2:
            batch_size_seq_len, d_model = gates.shape

+        with paddle.amp.auto_cast(False):
+            gates = gates.cast(self.weight.dtype)
+            if not self.transpose_gate_weight:
+                logits = F.linear(gates.cast("float32"), self.weight.cast("float32"))
+            else:
+                logits = F.linear(gates.cast("float32"), self.weight.cast("float32").t())
+            gates = self.gate_score_func(logits=logits)
+            gates = gates.cast(paddle.float32)
+
        gates_ori = gates
        if self.scoring_func == "sigmoid":
            gates_ori = gates_ori / (gates_ori.sum(axis=-1, keepdim=True) + 1e-20)

-        logits = F.linear(gates, self.weight)
-
-        gates = self.gate_score_func(logits=logits)
-
        l_zloss = self._cal_z_loss(gates)

        if self.topk_method == "greedy":
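The block added to topkgating above computes the routing logits in float32 and handles both gate-weight layouts. Since F.linear(x, w) computes x @ w, a weight stored as [num_experts, hidden] (transpose_gate_weight=True) is brought back to [hidden, num_experts] with .t(). A small sketch with made-up sizes:

import paddle
import paddle.nn.functional as F

hidden, num_experts = 16, 8
x = paddle.randn([4, hidden])                     # token hidden states

w_t = paddle.randn([num_experts, hidden])         # checkpoint stores the transposed layout
logits = F.linear(x.cast("float32"), w_t.cast("float32").t())  # [4, num_experts]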
@@ -506,9 +544,7 @@ def topkgating(
            denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
            if self.norm_topk_prob:
                gates_masked = gates_masked / denom_s
-            gates_masked = gates_masked.to(gates.dtype)
            gates_masked *= self.routed_scaling_factor
-
        return (
            capacity,  # new capacity
            top_gate,  # weights of selected experts for each token [num_tokens, num_experts_per_token]