Skip to content

Commit 87f1783

Browse files
[dlinfer] fix moe op for dlinfer. (#2917)
* fix fused_moe
* refine code.
* refine code.

---------

Co-authored-by: yaofengchen <[email protected]>
1 parent 33f5b19 commit 87f1783

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

lmdeploy/pytorch/backends/dlinfer/moe.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,8 @@ def forward(self,
4747
down_weights: torch.Tensor,
4848
expert_list: List[int] = None):
4949
"""forward."""
50-
return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights,
51-
gate_up_weights, down_weights)
50+
return fused_moe(hidden_states, gate_up_weights, down_weights,
51+
topk_weights, topk_ids, self.top_k, self.renormalize)
5252

5353

5454
class DlinferFusedMoEBuilder(FusedMoEBuilder):

lmdeploy/pytorch/kernels/dlinfer/fused_moe.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,13 @@
55

66
def fused_moe(
77
hidden_states: Tensor,
8-
top_k: int,
9-
topk_ids: Tensor,
10-
topk_weights: Tensor,
118
gate_up_weights: Tensor,
129
down_weights: Tensor,
10+
topk_weights: Tensor,
11+
topk_ids: Tensor,
12+
topk: int,
13+
renormalize: bool,
1314
):
14-
"""ascend fused moe."""
15-
return ext_ops.fused_moe(hidden_states, top_k, topk_ids, topk_weights,
16-
gate_up_weights, down_weights)
15+
"""dlinfer fused moe."""
16+
return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights,
17+
topk_weights, topk_ids, topk, renormalize)

0 commit comments

Comments
 (0)