2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -747,6 +747,8 @@ class FusedMoEConfig:

has_bias: bool = False

is_act_and_mul: bool = True

def __post_init__(self):
if self.dp_size > 1:
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
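Note: `is_act_and_mul` distinguishes the usual gated MoE MLP, where the fused `w13` projection produces a gate half and an up half combined as `act(gate) * up`, from a non-gated MLP where the activation is applied directly to a single up projection. A minimal sketch of the two conventions (standalone illustration, not vLLM's kernel code; `silu` is only used as the example activation):

```python
import torch
import torch.nn.functional as F


def apply_w13_activation(x: torch.Tensor, is_act_and_mul: bool) -> torch.Tensor:
    """Sketch of the two activation conventions toggled by is_act_and_mul.

    x: output of the fused w13 projection with shape (num_tokens, N), where
    N = 2 * intermediate_size in the gated case and N = intermediate_size
    in the non-gated case.
    """
    if is_act_and_mul:
        # Gated ("act and mul"): first half gates, second half is the up proj.
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
    # Non-gated: the activation is applied directly, no split and no multiply.
    return F.silu(x)
```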
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1458,6 +1458,7 @@ def fused_experts(

SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")


def _get_config_quant_dtype(
@@ -1667,7 +1668,9 @@ def fused_experts_impl(
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))

elif activation == RELU2_NO_MUL:
intermediate_cache2 = torch.square(
F.relu(intermediate_cache1.view(-1, N)))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

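The new `RELU2_NO_MUL` branch is squared ReLU applied to the whole `w13` output, with no gate/up split. A quick equivalence check (standalone sketch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
# torch.square(F.relu(x)) is elementwise ReLU followed by squaring,
# i.e. relu2(x) == relu(x) ** 2; negative inputs map to 0.
assert torch.allclose(torch.square(F.relu(x)), F.relu(x) ** 2)
```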
40 changes: 30 additions & 10 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -362,20 +362,22 @@ def select_gemm_impl(
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
w13_up_dim,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=params_dtype),
w13_bias = torch.nn.Parameter(torch.zeros(num_experts,
w13_up_dim,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
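The effect of `w13_up_dim` is on parameter shapes: gated experts keep the fused `[w1; w3]` (gate_up) layout along the output dim, while non-gated experts allocate only the single up projection. A rough shape sketch under those assumptions (example sizes; `w2`/down_proj is not part of this hunk and is shown only for contrast):

```python
import torch

num_experts, hidden_size, intermediate_per_partition = 8, 1024, 4096

for is_act_and_mul in (True, False):
    w13_up_dim = (2 * intermediate_per_partition
                  if is_act_and_mul else intermediate_per_partition)
    # Fused gate_up_proj when gated, plain up_proj when non-gated.
    w13_weight = torch.empty(num_experts, w13_up_dim, hidden_size)
    # down_proj keeps the same shape in both cases.
    w2_weight = torch.empty(num_experts, hidden_size, intermediate_per_partition)
    print(is_act_and_mul, tuple(w13_weight.shape), tuple(w2_weight.shape))
```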
@@ -954,6 +956,7 @@ def __init__(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
@@ -1092,6 +1095,7 @@ def __init__(
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
)
self.moe_config = moe
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
@@ -1107,6 +1111,19 @@ def __init__(
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method

if not self.moe_config.is_act_and_mul:
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod) # Avoid circular import
if not isinstance(
quant_method,
(UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now")
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now")

if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
@@ -1314,7 +1331,10 @@ def _load_w13(self,

# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
if self.moe_config.is_act_and_mul:
shard_size = expert_data.shape[shard_dim] // 2
else:
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
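For tensor-parallel weight loading, the `shard_size` change means: with a gated layout the local `w13` output dim stacks two shards (one for w1, one for w3), each of size `intermediate_size_per_partition`, so the slice taken from the checkpoint is half the local dim; with a non-gated layout the whole local dim is a single shard. A small numeric sketch of that arithmetic (illustrative values only):

```python
# Example: intermediate_size_per_partition = 1024 per TP rank.
intermediate_size_per_partition = 1024

# Gated: expert_data's output dim is 2 * 1024 (w1 and w3 stacked),
# but w1 and w3 are loaded separately, so each slice is 1024 rows.
gated_local_dim = 2 * intermediate_size_per_partition
gated_shard_size = gated_local_dim // 2        # -> 1024

# Non-gated: expert_data's output dim is 1024 and is a single shard.
non_gated_local_dim = intermediate_size_per_partition
non_gated_shard_size = non_gated_local_dim     # -> 1024

# In both cases rank `tp_rank` then copies
# loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size).
```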
26 changes: 20 additions & 6 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -305,7 +305,8 @@ def __init__(
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe() and \
self.moe.is_act_and_mul:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
@@ -353,9 +354,14 @@ def create_weights(
params_dtype)
weight_loader = extra_weight_attrs.get("weight_loader")

if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition

w13_weight = ModelWeightParameter(
data=torch.empty(num_experts,
2 * intermediate_size_per_partition,
w13_up_dim,
hidden_size,
dtype=weight_dtype),
input_dim=2,
@@ -377,11 +383,16 @@

if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively.
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
if self.moe.is_act_and_mul:
w13_weight_scale_shape = (num_experts, 2)
else:
w13_weight_scale_shape = (num_experts, 1)
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
w13_weight_scale_shape,
1.0,
dtype=torch.float32,
),
@@ -429,7 +440,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2:
if layer.w13_weight_scale.dim() == 2 and \
layer.w13_weight_scale.shape[1] == 2:
assert self.moe.is_act_and_mul, (
"w13_weight_scale should have 2 elements per expert "
"only for gated MoE")

# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
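For per-tensor FP8 checkpoints, the gated path still carries one scale each for w1 and w3 (shape `(num_experts, 2)`) and folds them into a single per-expert scale by taking the max and requantizing each half; the non-gated path already has a single scale per expert (shape `(num_experts, 1)`) and skips the fold. A hedged sketch of that fold, assuming plain symmetric per-tensor FP8 quantization (helper name and loop are illustrative, not the vLLM implementation):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def fold_w13_scales(w13_q: torch.Tensor,
                    w13_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """w13_q: (num_experts, 2 * I, H) fp8 weights, w13_scale: (num_experts, 2)."""
    num_experts, two_i, _ = w13_q.shape
    half = two_i // 2
    max_scale = w13_scale.max(dim=1).values           # one scale per expert
    out = torch.empty_like(w13_q)
    for e in range(num_experts):
        for j in range(2):                            # j == 0 -> w1, j == 1 -> w3
            chunk = w13_q[e, j * half:(j + 1) * half]
            dequant = chunk.to(torch.float32) * w13_scale[e, j]
            requant = (dequant / max_scale[e]).clamp(-FP8_MAX, FP8_MAX)
            out[e, j * half:(j + 1) * half] = requant.to(torch.float8_e4m3fn)
    return out, max_scale
```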
@@ -1437,7 +1452,6 @@ def apply(
if (self.allow_flashinfer and self.flashinfer_moe_backend
== FlashinferMoeBackend.TENSORRT_LLM):
import flashinfer

from vllm.model_executor.models.llama4 import Llama4MoE

assert self.fused_experts is None