@@ -5,6 +5,7 @@
 import torch
 from torch.nn import functional as F
 
+from vllm import _custom_ops as ops
 from vllm import envs
 
 
@@ -237,7 +238,43 @@ def __call__(
 
 class CPUFusedMOE:
     def __init__(self, layer: torch.nn.Module) -> None:
-        pass
+        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
+
+        num_experts = layer.w13_weight.size(0)
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
+        layer.gate_up_linear = []
+        layer.down_linear = []
+
+        for i in range(num_experts):
+            layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
+            layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
+            if use_onednn_mm:
+                gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
+                layer.gate_up_linear.append(
+                    lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+                down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
+                layer.down_linear.append(
+                    lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+            else:
+                layer.gate_up_linear.append(
+                    lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
+                )
+                layer.down_linear.append(
+                    lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
+                )
+        if use_onednn_mm:  # remove weight
+            layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
 
     def __call__(
         self,
@@ -287,28 +324,19 @@ def __call__(
 
         outputs = []
         start_idx = 0
-        has_w13_bias = hasattr(layer, "w13_bias")
-        has_w2_bias = hasattr(layer, "w2_bias")
 
         for i, num_tokens in enumerate(tokens_per_expert):
             end_idx = start_idx + num_tokens
             if num_tokens == 0:
                 continue
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
-            layer_w13_weight = layer.w13_weight[i]
-            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
-            layer_w2_weight = layer.w2_weight[i]
-            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
-
-            gate_up = F.linear(
-                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
-            )
+            gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
             if activation == "swigluoai":
                 gate_up = swigluoai_and_mul(gate_up)
             else:
                 gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
+            expert_out = layer.down_linear[i](gate_up)
             outputs.append(expert_out)
             start_idx = end_idx
 
|
|