Commit c757a15

[CPU] Improve CPU fused MoE perf (vllm-project#27244)
Signed-off-by: Zhang Xiangze <[email protected]>
1 parent 59a50af commit c757a15

File tree

1 file changed: +40 −12 lines changed
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py

Lines changed: 40 additions & 12 deletions
@@ -5,6 +5,7 @@
 import torch
 from torch.nn import functional as F
 
+from vllm import _custom_ops as ops
 from vllm import envs
 
 
@@ -237,7 +238,43 @@ def __call__(
 
 class CPUFusedMOE:
     def __init__(self, layer: torch.nn.Module) -> None:
-        pass
+        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
+
+        num_experts = layer.w13_weight.size(0)
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
+        layer.gate_up_linear = []
+        layer.down_linear = []
+
+        for i in range(num_experts):
+            layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
+            layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
+            if use_onednn_mm:
+                gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
+                layer.gate_up_linear.append(
+                    lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+                down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
+                layer.down_linear.append(
+                    lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+            else:
+                layer.gate_up_linear.append(
+                    lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
+                )
+                layer.down_linear.append(
+                    lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
+                )
+        if use_onednn_mm:  # remove weight
+            layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
 
     def __call__(
         self,
@@ -287,28 +324,19 @@ def __call__(
 
         outputs = []
         start_idx = 0
-        has_w13_bias = hasattr(layer, "w13_bias")
-        has_w2_bias = hasattr(layer, "w2_bias")
 
         for i, num_tokens in enumerate(tokens_per_expert):
             end_idx = start_idx + num_tokens
             if num_tokens == 0:
                 continue
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
-            layer_w13_weight = layer.w13_weight[i]
-            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
-            layer_w2_weight = layer.w2_weight[i]
-            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
-
-            gate_up = F.linear(
-                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
-            )
+            gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
             if activation == "swigluoai":
                 gate_up = swigluoai_and_mul(gate_up)
             else:
                 gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
+            expert_out = layer.down_linear[i](gate_up)
             outputs.append(expert_out)
             start_idx = end_idx
 
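A note on the lambdas introduced in __init__ above: each one binds its expert's handle, weight, and bias through default arguments (handle=gate_up_handle, w=layer_w13_weight, ...). The following is a minimal, self-contained sketch in plain PyTorch (not vLLM code; the toy weights are made up) of why default-argument binding matters when closures are created inside a loop.

import torch
from torch.nn import functional as F

# Stand-in per-expert weights; shapes and values are arbitrary for illustration.
weights = [torch.eye(2) * (i + 1) for i in range(3)]

# Late binding: every closure reads the loop variable `w` when it is *called*,
# so all three end up using weights[2].
broken = [lambda x: F.linear(x, w) for w in weights]

# Default-argument binding, the pattern used in CPUFusedMOE.__init__: each
# closure captures its own expert's weight at definition time.
bound = [lambda x, w=w: F.linear(x, w) for w in weights]

x = torch.ones(1, 2)
print([fn(x) for fn in broken])  # every entry uses weights[2]
print([fn(x) for fn in bound])   # entry i uses weights[i]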
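For readers skimming the diff, here is a simplified, self-contained sketch of the expert dispatch loop that __call__ now runs, calling precomputed per-expert callables instead of F.linear on indexed weights. It is plain PyTorch with made-up shapes and a toy silu_and_mul helper; it is not the vLLM implementation (the oneDNN path, routing, and expert weighting are omitted).

import torch
from torch.nn import functional as F

hidden, inter, num_experts = 8, 16, 4

# Precomputed per-expert projections, analogous to layer.gate_up_linear /
# layer.down_linear built once in CPUFusedMOE.__init__ (fallback F.linear path).
gate_up_linear = [
    lambda x, w=torch.randn(2 * inter, hidden): F.linear(x, w) for _ in range(num_experts)
]
down_linear = [
    lambda x, w=torch.randn(hidden, inter): F.linear(x, w) for _ in range(num_experts)
]

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # SwiGLU-style activation on the fused gate/up projection output.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

sorted_tokens = torch.randn(10, hidden)  # tokens already grouped by expert
tokens_per_expert = [3, 0, 5, 2]         # token counts per expert (sums to 10)

outputs, start_idx = [], 0
for i, num_tokens in enumerate(tokens_per_expert):
    end_idx = start_idx + num_tokens
    if num_tokens == 0:
        continue
    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
    gate_up = gate_up_linear[i](tokens_for_this_expert)   # fused W1/W3 projection
    expert_out = down_linear[i](silu_and_mul(gate_up))    # W2 projection
    outputs.append(expert_out)
    start_idx = end_idx

print(torch.cat(outputs).shape)  # torch.Size([10, 8])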