62 changes: 20 additions & 42 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -63,7 +63,7 @@ def should_use_flashinfer_mxfp4():

 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
     import aiter
-    from aiter.fused_moe import fused_topk, moe_sorting
+    from aiter.fused_moe import fused_moe, fused_topk, moe_sorting
     from aiter.ops.shuffle import shuffle_mxfp4_weight, shuffle_mxfp4_scale

 class Mxfp4Config(QuantizationConfig):
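
Not part of this diff: a minimal sketch of the gate that the conditional import above sits behind. The aiter MXFP4 fused-MoE path is only reachable on ROCm with both AITER flags enabled; the helper name below is hypothetical, while the env flags and platform check are the ones referenced in the hunk.

# Hypothetical helper (not in the PR) mirroring the guard around the aiter imports.
import vllm.envs as envs
from vllm.platforms import current_platform


def _use_aiter_mxfp4_fused_moe() -> bool:
    # The fused a16w4 MoE path only applies on ROCm with both AITER flags set.
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4)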
@@ -690,51 +690,29 @@ def apply(
 token_num = x.shape[0]
 BLOCKM = 16 if token_num < 2048 else 32
 topk_weights, topk_ids = fused_topk(x, router_logits, top_k, True)
-sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_out = moe_sorting(
-    topk_ids,
-    topk_weights,
-    self.num_experts,
-    x.shape[1],
-    torch.bfloat16,
-    BLOCKM
-)
-_, n1, k1 = self.w13_weight_aiter_tensor.shape
-_, k2, n2 = self.w2_weight_aiter_tensor.shape
-D = n2 if k2 == k1 else n2*2
-cktile_moe_out1 = torch.empty((token_num, top_k, D), dtype=torch.bfloat16, device=x.device)
-aiter.moe_cktile2stages_gemm1(
+return fused_moe(
     x,
     self.w13_weight_aiter_tensor,
-    cktile_moe_out1,
-    sorted_ids,
-    sorted_expert_ids,
-    num_valid_ids,
-    top_k,
-    self.intermediate_pad // 64 * 64 * 2,
-    self.hidden_pad // 128 * 128,  # k_pad_zeros
-    None,  # sorted_weights
-    None,
-    self.w13_scale_aiter_tensor,
-    self.w13_bias_aiter_tensor,
-    BLOCKM,  # block_size
-)
-aiter.moe_cktile2stages_gemm2(
-    cktile_moe_out1,
     self.w2_weight_aiter_tensor,
-    moe_out,
-    sorted_ids,
-    sorted_expert_ids,
-    num_valid_ids,
-    top_k,
-    self.hidden_pad // 64 * 64,  # n_pad_zeros
-    self.intermediate_pad // 128 * 128,
-    sorted_weights,  # sorted_weights
-    None,
-    self.w2_scale_aiter_tensor,
-    layer.w2_bias,
-    BLOCKM,  # block_size
+    topk_weights,
+    topk_ids,
+    expert_mask=None,
+    activation=aiter.ActivationType.Swiglu,
+    quant_type=aiter.QuantType.per_1x32,
+    doweight_stage1=False,
+    w1_scale=self.w13_scale_aiter_tensor,
+    w2_scale=self.w2_scale_aiter_tensor,
+    a1_scale=None,
+    a2_scale=None,
+    block_size_M=BLOCKM,
+    num_local_tokens=None,
+    moe_sorting_dispatch_policy=0,
+    dtype=None,
+    hidden_pad=self.hidden_pad,
+    intermediate_pad=self.intermediate_pad,
+    bias1=self.w13_bias_aiter_tensor,
+    bias2=layer.w2_bias,
 )
-return moe_out

 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
     triton_kernel_moe_forward)
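
As a reading aid, here is a consolidated sketch (not from the PR) of the path the new code takes: expert selection via fused_topk, then a single fused_moe call in place of the removed moe_sorting + moe_cktile2stages_gemm1/gemm2 sequence. The wrapper name and its flattened argument list are illustrative assumptions; the fused_moe keyword arguments mirror the diff above.

# Illustrative wrapper (not in the PR); parameter names are assumptions,
# the fused_moe call itself follows the added lines of the diff.
import aiter
from aiter.fused_moe import fused_moe, fused_topk


def aiter_mxfp4_moe(x, router_logits, top_k, w13, w2, w13_scale, w2_scale,
                    w13_bias, w2_bias, hidden_pad, intermediate_pad):
    token_num = x.shape[0]
    block_m = 16 if token_num < 2048 else 32  # same block-size heuristic as the diff
    topk_weights, topk_ids = fused_topk(x, router_logits, top_k, True)
    # One fused kernel replaces the removed sorting + two-stage GEMM sequence.
    return fused_moe(
        x, w13, w2, topk_weights, topk_ids,
        expert_mask=None,
        activation=aiter.ActivationType.Swiglu,
        quant_type=aiter.QuantType.per_1x32,
        doweight_stage1=False,
        w1_scale=w13_scale, w2_scale=w2_scale,
        a1_scale=None, a2_scale=None,
        block_size_M=block_m,
        num_local_tokens=None,
        moe_sorting_dispatch_policy=0,
        dtype=None,
        hidden_pad=hidden_pad,
        intermediate_pad=intermediate_pad,
        bias1=w13_bias, bias2=w2_bias,
    )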