Commit bf8047e

[Unified MoE Layer]: Support GLM4.5 (#2842)
1 parent 0881734 · commit bf8047e

7 files changed (+168 -62 lines)

paddleformers/nn/moe_deepep/modular_moe_layer.py

Lines changed: 55 additions & 31 deletions
@@ -29,6 +29,7 @@
 from .moe_communication import AllToAllMoECommunication, DeepEPMoECommunication
 from .moe_expert import StandardMLPExpert
 from .moe_gate import StandardMoEGate
+from .moe_loss import AddAuxiliaryLoss
 from .moe_loss_instance import get_global_loss_registry

 logger = logging.getLogger(__name__)
@@ -48,6 +49,7 @@ def __init__(
         moe_config: Dict,
         model_type: str,
         expert_class,
+        transpose_gate_weight: bool,
         pretrained_config: Optional[PretrainedConfig] = None,
     ):

@@ -61,23 +63,25 @@ def __init__(
         self.norm_topk_prob = norm_topk_prob
         self.model_type = model_type
         self.expert_class = expert_class
+        self.transpose_gate_weight = transpose_gate_weight

         self.sequence_parallel = pretrained_config.get("sequence_parallel", False)
         self.tensor_parallel_degree = pretrained_config.get("tensor_parallel_degree", 1)
         self.seq_length = pretrained_config.get("seq_length", pretrained_config.get("max_seq_len", 1024))
         self.fuse_up_gate = pretrained_config.get("fuse_attention_ffn", False)
         self.ep_communication_type = pretrained_config.get("ep_communication_type", "deepep")
+        self.n_group = pretrained_config.get("n_group", 1)
+        self.topk_group = pretrained_config.get("topk_group", 1)
+        self.routed_scaling_factor = pretrained_config.get("routed_scaling_factor", 1.0)
+        self.aux_loss_alpha = pretrained_config.get("aux_loss_alpha", 0.0)
+        self.moe_subbatch_token_num = pretrained_config.get("moe_subbatch_token_num", -1)
         try:
             moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
         except Exception:
             moe_group = None
         self.expert_parallel_degree = dist.get_world_size(moe_group) if moe_group is not None else 1

-        self.custom_gate = moe_config.get("custom_gate", None)
-        self.custom_communication = moe_config.get("custom_communication", None)
         self.gate_activation = moe_config.get("gate_activation", "softmax")
-        self.aux_loss_weight = moe_config.get("aux_loss_weight", 0.01)
-        self.z_loss_weight = moe_config.get("z_loss_weight", 0.0)
         self.topk_method = (
             moe_config.get("train_topk_method", "greedy")
             if self.training
@@ -92,19 +96,23 @@ def __init__(
         self.loss_combiner_name = moe_config.get("loss_combiner_name", "weighted_sum")

         self._init_expert_parallel()
-        if self.custom_gate is not None:
-            self.gate = self.custom_gate
-        else:
-            self.gate = StandardMoEGate(
-                num_experts=self.num_experts,
-                expert_hidden_size=self.hidden_size,
-                drop_tokens=self.drop_tokens,
-                topk_method=self.topk_method,
-                num_experts_per_tok=self.num_experts_per_tok,
-                norm_topk_prob=self.norm_topk_prob,
-                moe_config=moe_config,
-                seq_length=self.seq_length,
-            )
+        self.gate = StandardMoEGate(
+            num_experts=self.num_experts,
+            expert_hidden_size=self.hidden_size,
+            drop_tokens=self.drop_tokens,
+            topk_method=self.topk_method,
+            num_experts_per_tok=self.num_experts_per_tok,
+            norm_topk_prob=self.norm_topk_prob,
+            moe_config=moe_config,
+            seq_length=self.seq_length,
+            n_group=self.n_group,
+            topk_group=self.topk_group,
+            routed_scaling_factor=self.routed_scaling_factor,
+            moe_subbatch_token_num=self.moe_subbatch_token_num,
+            tensor_parallel_degree=self.tensor_parallel_degree,
+            sequence_parallel=self.sequence_parallel,
+            transpose_gate_weight=self.transpose_gate_weight,
+        )

         if self.expert_class is None:
             self.expert_class = StandardMLPExpert
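Note on the new gate wiring: the custom_gate and custom_communication escape hatches are removed, and the layer now always builds a StandardMoEGate, forwarding the GLM4.5-style routing hyperparameters it reads from the pretrained config. A minimal sketch of the keys involved; the values below are illustrative placeholders with the same fallbacks the code uses, not GLM4.5's released settings:

    # Hypothetical config dict for illustration only; a real run reads these
    # keys from the model's PretrainedConfig.
    pretrained_config = {
        "n_group": 1,                  # expert groups for group-limited / noaux_tc routing
        "topk_group": 1,               # how many groups a token may route into
        "routed_scaling_factor": 1.0,  # scale applied to the routed gate weights
        "aux_loss_alpha": 0.0,         # > 0.0 attaches the auxiliary loss in forward()
        "moe_subbatch_token_num": -1,  # sub-batch size for the sequence-level aux loss
    }

    # Mirrors the reads in ModularMoELayer.__init__ shown above.
    n_group = pretrained_config.get("n_group", 1)
    aux_loss_alpha = pretrained_config.get("aux_loss_alpha", 0.0)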
@@ -124,8 +132,14 @@ def __init__(
         if self.model_type == "qwen3_moe":
             pass
         elif self.model_type == "glm4_moe":
-            pass
-        self.experts = nn.LayerList([self.expert_class(**expert_args) for _ in range(self.num_experts)])
+            expert_args["fuse_up_gate"] = self.fuse_up_gate
+
+        self.experts = nn.LayerList([])
+        for i in range(self.num_experts):
+            if i // self.num_experts_per_device == self.moe_rank:
+                self.experts.append(self.expert_class(**expert_args))
+            else:
+                self.experts.append(None)

         if self.expert_parallel_degree > 1:
             self.token_dispatcher = MoEFlexTokenDispatcher(
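The expert list now only materializes the experts owned by the local expert-parallel rank; the other slots hold None, so global expert indices stay stable while parameters for remote experts are never allocated. A self-contained sketch of the index arithmetic, with sizes chosen purely for illustration (num_experts_per_device and moe_rank come from the layer's expert-parallel setup, which is not part of this diff; the sketch assumes num_experts_per_device = num_experts // expert_parallel_degree):

    # Toy sizes, not from any real config.
    num_experts = 8
    expert_parallel_degree = 4
    num_experts_per_device = num_experts // expert_parallel_degree  # 2 experts per rank

    for moe_rank in range(expert_parallel_degree):
        local_ids = [i for i in range(num_experts) if i // num_experts_per_device == moe_rank]
        print(f"rank {moe_rank} instantiates experts {local_ids}; the rest stay None")
    # rank 0 -> [0, 1], rank 1 -> [2, 3], rank 2 -> [4, 5], rank 3 -> [6, 7]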
@@ -137,22 +151,25 @@
         shared_expert_args = {}
         shared_expert_args["config"] = shared_expert_pretrained_config
         shared_expert_args["intermediate_size"] = self.moe_intermediate_size * self.num_shared_experts
+        # Add more arguments for different models
+        if self.model_type == "qwen3_moe":
+            pass
+        elif self.model_type == "glm4_moe":
+            shared_expert_args["fuse_up_gate"] = self.fuse_up_gate
+
         if self.num_shared_experts > 0:
             self.shared_experts = self.expert_class(**shared_expert_args)
         else:
             self.shared_experts = None

-        if self.custom_communication is not None:
-            self.communication = self.custom_communication
+        if self.ep_communication_type == "deepep":
+            self.communication = DeepEPMoECommunication()
+        elif self.ep_communication_type == "alltoall":
+            self.communication = AllToAllMoECommunication()
         else:
-            if self.ep_communication_type == "deepep":
-                self.communication = DeepEPMoECommunication()
-            elif self.ep_communication_type == "alltoall":
-                self.communication = AllToAllMoECommunication()
-            else:
-                raise ValueError(
-                    f"Unsupported communication type: {self.ep_communication_type}, please choose from ['deepep', 'alltoall']"
-                )
+            raise ValueError(
+                f"Unsupported communication type: {self.ep_communication_type}, please choose from ['deepep', 'alltoall']"
+            )

         if hasattr(dist, "fleet") and dist.is_initialized() and self.expert_parallel_degree > 1:
             self.is_mp_moe = False
@@ -224,6 +241,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         capacity, topk_weights, topk_indices, gates_masked, mask, priorities, aux_loss, z_loss = self.gate(
             hidden_states
         )
+        # topk_weights, topk_indices will be used in AllToAllMoECommunication
+        # gates_masked, mask will be used in DeepEPMoECommunication
+        # capacity, priorities are not used currently

         if self.expert_parallel_degree > 1:
             output = self._forward_with_ep_parallel(
@@ -237,16 +257,20 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
             reshaped_input = hidden_states
             output = self._forward_traditional_moe(reshaped_input, topk_indices, topk_weights)

-        output = output.reshape(orig_shape)
+        if self.training and self.aux_loss_alpha > 0.0:
+            aux_loss = aux_loss * self.aux_loss_alpha
+            output = AddAuxiliaryLoss.apply(output, aux_loss)

         if self.shared_experts is not None:
             shared_output = self.shared_experts(residuals)
             output = output + shared_output

+        output = output.reshape(orig_shape)
+
         if self.expert_parallel_degree <= 1 and self.sequence_parallel:
             output = ScatterOp.apply(output)

-        return output, aux_loss
+        return output

     def _forward_traditional_moe(
         self, hidden_states: paddle.Tensor, selected_experts: paddle.Tensor, topk_weights: paddle.Tensor
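With this change the layer's public contract shifts: the auxiliary loss is scaled by aux_loss_alpha and folded into the activation graph through AddAuxiliaryLoss, so forward() returns a single tensor instead of an (output, aux_loss) tuple. A hedged caller-side sketch; the surrounding training-loop names are assumptions, not code from this repository:

    # Before this commit (sketch): the caller combined the aux loss itself.
    #   output, aux_loss = moe_layer(hidden_states)
    #   loss = task_loss + aux_loss
    # After this commit (sketch): the aux-loss gradient rides along with `output`,
    # so the caller only consumes the hidden states.
    #   output = moe_layer(hidden_states)
    #   loss = task_loss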

paddleformers/nn/moe_deepep/moe_factory.py

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,7 @@ def create_from_model_name(
     train_topk_method: str,
     inference_topk_method: str,
     drop_tokens: bool,
+    transpose_gate_weight: bool,
 ) -> ModularMoELayer:
     model_type = getattr(pretrained_config, "model_type", None)
     if model_type is None:
@@ -55,6 +56,7 @@ def create_from_model_name(
         moe_config=moe_config,
         model_type=model_type,
         expert_class=expert_class,
+        transpose_gate_weight=transpose_gate_weight,
         pretrained_config=pretrained_config,
     )

paddleformers/nn/moe_deepep/moe_gate.py

Lines changed: 66 additions & 30 deletions

@@ -29,26 +29,26 @@

 class MoEGateMixin:
     def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor:
+        # [..., hidden_dim] -> [..., num_experts]
         with paddle.amp.auto_cast(False):
-            # [..., hidden_dim] -> [..., num_experts]
             scoring_func = getattr(self, "scoring_func", None)
             if scoring_func == "softmax":
-                scores = F.softmax(logits, axis=-1)
+                scores = F.softmax(logits.cast("float32"), axis=-1)
             elif scoring_func == "sigmoid":
-                scores = F.sigmoid(logits)
+                scores = F.sigmoid(logits.cast("float32"))
             elif scoring_func == "tanh":
-                scores = F.tanh(logits)
+                scores = F.tanh(logits.cast("float32"))
             elif scoring_func == "relu":
-                scores = F.relu(logits)
+                scores = F.relu(logits.cast("float32"))
             elif scoring_func == "gelu":
-                scores = F.gelu(logits)
+                scores = F.gelu(logits.cast("float32"))
             elif scoring_func == "leaky_relu":
-                scores = F.leaky_relu(logits)
+                scores = F.leaky_relu(logits.cast("float32"))
             else:
                 logger.warning_once(
                     f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead"
                 )
-                scores = F.softmax(logits, axis=-1)
+                scores = F.softmax(logits.cast("float32"), axis=-1)
         return scores

     def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor:
@@ -130,9 +130,7 @@ def _cal_aux_loss(self, gates, mask):
         aux_loss = paddle.sum(me * ce) * float(self.num_experts)
         return aux_loss

-    def _cal_seq_aux_loss(self, probs, top_k, routing_map, seq_length):
-        max_seq_len = seq_length
-
+    def _cal_seq_aux_loss(self, probs, top_k, routing_map, max_seq_len):
         sub_max_seq_len = max_seq_len
         if hasattr(self, "moe_subbatch_token_num") and self.moe_subbatch_token_num > 0:
             sub_max_seq_len = self.moe_subbatch_token_num * self.tensor_parallel_degree
@@ -162,7 +160,6 @@ def _cal_seq_aux_loss(self, probs, top_k, routing_map, seq_length):
         )
         # [B, E] -> [B] -> []
         seq_aux_loss = (cost_coeff * all_probs.sum(axis=seq_axis) / max_seq_len).sum(axis=1).mean()
-
         return seq_aux_loss

     def _cal_z_loss(self, logits) -> paddle.Tensor:
@@ -361,6 +358,9 @@ def _topk_noaux_tc(
        )  # [n, e]
        tmp_scores = scores_for_choice * score_mask  # [n, e]
        topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True)
+
+        # The bias term b is used only to adjust affinity scores for Top-K expert selection (routing); it does not affect gating.
+        # The gate applied during dispatch and to weight the FFN output is computed from the original affinity score s_{i,t} (without the bias).
        topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight

        return topk_weight, topk_idx
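The two comments above describe DeepSeek-style "no-aux" routing: a learned per-expert correction bias shifts the scores used to pick the top-k experts, while the weights used to combine expert outputs still come from the unbiased scores. A self-contained sketch of that split, following the inference branch above (group-limited masking omitted; the numbers are made up for illustration):

    import paddle

    scores = paddle.to_tensor([[0.30, 0.25, 0.25, 0.20]])             # s_{i,t}, shape [n, e]
    e_score_correction_bias = paddle.to_tensor([0.0, 0.0, 0.2, 0.0])  # per-expert bias b

    # Selection uses the biased scores ...
    scores_for_choice = scores + e_score_correction_bias
    _, topk_idx = paddle.topk(scores_for_choice, k=2, axis=-1, sorted=True)

    # ... but the dispatch/combine weights come from the original scores.
    topk_weight = scores.take_along_axis(topk_idx, axis=1)
    print(topk_idx.numpy(), topk_weight.numpy())  # picks experts [2, 0] with weights [0.25, 0.30]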
@@ -378,6 +378,13 @@ def __init__(
        norm_topk_prob: bool,
        moe_config: Dict,
        seq_length: int,
+        n_group: int,
+        topk_group: int,
+        routed_scaling_factor: float,
+        moe_subbatch_token_num: int,
+        tensor_parallel_degree: int,
+        sequence_parallel: bool,
+        transpose_gate_weight: bool,
    ):
        super(StandardMoEGate, self).__init__()

@@ -390,8 +397,15 @@ def __init__(
        # force keep in float32 when using amp
        self._cast_to_low_precision = False
        self.seq_length = seq_length
-
-        self.scoring_func = moe_config.get("scoring_func", "softmax")
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.routed_scaling_factor = routed_scaling_factor
+        self.moe_subbatch_token_num = moe_subbatch_token_num
+        self.tensor_parallel_degree = tensor_parallel_degree
+        self.sequence_parallel = sequence_parallel
+        self.transpose_gate_weight = transpose_gate_weight
+
+        self.scoring_func = moe_config.get("gate_activation", "softmax")
        self.capacity_factor = moe_config.get("capacity_factor", 1.0)
        self.eval_capacity_factor = moe_config.get("eval_capacity_factor", 1.0)
        self.min_capacity = moe_config.get("min_capacity", 1)
@@ -401,26 +415,45 @@ def __init__(
        self.use_rts = moe_config.get("use_rts", True)
        self.top2_2nd_expert_sampling = moe_config.get("top2_2nd_expert_sampling", True)
        self.drop_policy = moe_config.get("drop_policy", "probs")
-        self.n_group = moe_config.get("n_group", 1)  # for group_limited_greedy
-        self.topk_group = moe_config.get("topk_group", 1)  # for group_limited_greedy
-        self.routed_scaling_factor = moe_config.get("routed_scaling_factor", 1.0)
-        self.seq_aux = moe_config.get("seq_aux", False)
+        self.seq_aux = moe_config.get("seq_aux", True)

        if self.global_aux_loss:
            assert self.group is not None, "group is required when global_aux_loss is True"
            self.rank = dist.get_rank(self.group)

-        self.weight = paddle.create_parameter(
-            shape=[self.expert_hidden_size, self.num_experts],
-            dtype="float32",
-            default_initializer=paddle.nn.initializer.Uniform(),
-        )
+        # Accordding to the shape of gate weights in model checkpoint
+        if not transpose_gate_weight:
+            self.weight = paddle.create_parameter(
+                shape=[self.expert_hidden_size, self.num_experts],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Uniform(),
+            )
+        else:
+            self.weight = paddle.create_parameter(
+                shape=[self.num_experts, self.expert_hidden_size],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Uniform(),
+            )
+
+        if self.topk_method == "noaux_tc":
+            self.register_buffer("e_score_correction_bias", paddle.zeros((self.num_experts,), dtype=paddle.float32))
+            self._cast_to_low_precision = False
+            self.expert_usage = paddle.zeros(
+                shape=[self.num_experts],
+                dtype=paddle.int64,
+            )  # Used in MoECorrectionBiasAdjustCallback
+            self.expert_usage.stop_gradient = True

    def forward(
        self,
        gates: paddle.Tensor,
    ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-        return self.topkgating(gates)
+        capacity, top_gate, top_idx, gates_masked, mask, token_priority, l_aux, l_zloss = self.topkgating(gates)
+        exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
+        if self.topk_method == "noaux_tc":
+            with paddle.no_grad():
+                self.expert_usage += exp_counts
+        return capacity, top_gate, top_idx, gates_masked, mask, token_priority, l_aux, l_zloss

    def topkgating(
        self,
@@ -434,14 +467,19 @@ def topkgating(
        elif len(gates.shape) == 2:
            batch_size_seq_len, d_model = gates.shape

+        with paddle.amp.auto_cast(False):
+            gates = gates.cast(self.weight.dtype)
+            if not self.transpose_gate_weight:
+                logits = F.linear(gates.cast("float32"), self.weight.cast("float32"))
+            else:
+                logits = F.linear(gates.cast("float32"), self.weight.cast("float32").t())
+            gates = self.gate_score_func(logits=logits)
+            gates = gates.cast(paddle.float32)
+
        gates_ori = gates
        if self.scoring_func == "sigmoid":
            gates_ori = gates_ori / (gates_ori.sum(axis=-1, keepdim=True) + 1e-20)

-        logits = F.linear(gates, self.weight)
-
-        gates = self.gate_score_func(logits=logits)
-
        l_zloss = self._cal_z_loss(gates)

        if self.topk_method == "greedy":
@@ -506,9 +544,7 @@ def topkgating(
        denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
        if self.norm_topk_prob:
            gates_masked = gates_masked / denom_s
-        gates_masked = gates_masked.to(gates.dtype)
        gates_masked *= self.routed_scaling_factor
-
        return (
            capacity,  # new capacity
            top_gate,  # weights of selected experts for each token [num_tokens, num_experts_per_token]
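The transpose_gate_weight flag exists because checkpoints disagree on the router weight layout: some store it as [hidden_size, num_experts], others as [num_experts, hidden_size], and the gate now creates whichever shape matches the checkpoint so the weights load without a transpose pass. Both branches of the F.linear call above compute the same logits; a standalone sketch with made-up sizes:

    import paddle
    import paddle.nn.functional as F

    tokens, hidden_size, num_experts = 8, 16, 4
    x = paddle.randn([tokens, hidden_size], dtype="float32")

    w = paddle.randn([hidden_size, num_experts], dtype="float32")  # transpose_gate_weight=False layout
    w_t = w.t()                                                    # [num_experts, hidden_size] checkpoint layout

    logits_plain = F.linear(x, w)             # branch taken when the flag is False
    logits_transposed = F.linear(x, w_t.t())  # branch taken when the flag is True
    print(paddle.allclose(logits_plain, logits_transposed).item())  # True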

paddleformers/nn/moe_deepep/moe_loss.py

Lines changed: 21 additions & 0 deletions

@@ -58,6 +58,27 @@ def __call__(
         pass


+class AddAuxiliaryLoss(paddle.autograd.PyLayer):
+    """
+    The trick function of adding auxiliary (aux) loss,
+    which includes the gradient of the aux loss during backpropagation.
+    """
+
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert paddle.numel(loss) == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = not loss.stop_gradient
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = paddle.ones(1, dtype=ctx.dtype)
+        return grad_output, grad_loss
+
+
 class LossCombiner(Protocol):
     def __call__(self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]) -> paddle.Tensor:
         pass
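A quick way to see what the PyLayer does: forward() is an identity on x, and backward() injects a unit gradient into the aux loss, so whatever produced that loss (the gate weights) still receives load-balancing gradients even though the scalar never appears in the user-visible loss. A hedged usage sketch, assuming the import path follows the file location above and using stand-in tensors rather than a real gate:

    import paddle
    from paddleformers.nn.moe_deepep.moe_loss import AddAuxiliaryLoss

    x = paddle.randn([4, 8])
    x.stop_gradient = False

    gate_param = paddle.ones([1])   # stand-in for a router parameter
    gate_param.stop_gradient = False
    aux_loss = gate_param * 0.5     # stand-in for the load-balancing loss (numel == 1)

    out = AddAuxiliaryLoss.apply(x, aux_loss)
    out.sum().backward()

    print(paddle.allclose(out, x).item())  # True: the output is numerically unchanged
    print(gate_param.grad)                 # [0.5]: the aux-loss gradient was injected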
