2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -797,6 +797,8 @@ class FusedMoEConfig:

has_bias: bool = False

is_act_and_mul: bool = True

def __post_init__(self):
if self.dp_size > 1:
logger.debug_once(
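To make the new flag concrete, here is a hedged sketch (the class name, property, and example sizes below are invented for illustration; only the `is_act_and_mul` field and the doubled-vs-single w13 width come from this PR):

```python
from dataclasses import dataclass


@dataclass
class MoEConfigSketch:
    """Toy stand-in for FusedMoEConfig (the real class has many more fields).

    is_act_and_mul=True  -> gated MLP: w13 fuses gate_proj and up_proj, and the
                            activated gate half is multiplied by the up half.
    is_act_and_mul=False -> non-gated MLP: w13 holds a single up projection and
                            the activation (e.g. ReLU^2) is applied directly.
    """

    intermediate_size_per_partition: int
    is_act_and_mul: bool = True

    @property
    def w13_up_dim(self) -> int:
        # Mirrors the shape logic added in layer.py and modelopt.py below.
        return (2 if self.is_act_and_mul else 1) * self.intermediate_size_per_partition


print(MoEConfigSketch(14336).w13_up_dim)                        # 28672 (gated)
print(MoEConfigSketch(14336, is_act_and_mul=False).w13_up_dim)  # 14336 (non-gated)
```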
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1621,6 +1621,7 @@ def fused_experts(

SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")


def _get_config_quant_dtype(
@@ -1888,7 +1889,8 @@ def fused_experts_impl(
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))

elif activation == RELU2_NO_MUL:
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

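As a reference for what the activation branches compute, a minimal PyTorch sketch (not the vLLM kernel path; the string literals stand in for the `activation_without_mul(...)` constants, and the gated branch here uses an eager silu-and-mul rather than the fused op):

```python
import torch
import torch.nn.functional as F


def moe_activation_sketch(x: torch.Tensor, activation: str) -> torch.Tensor:
    """Reference-only sketch of the gated vs. *_NO_MUL activation paths."""
    if activation == "silu":                   # gated: act(gate) * up
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
    if activation == "silu_no_mul":            # non-gated: act on the full tensor
        return F.silu(x)
    if activation == "gelu_no_mul":
        return F.gelu(x)
    if activation == "relu2_no_mul":           # squared ReLU, the branch added above
        return torch.square(F.relu(x))
    raise ValueError(f"Unsupported activation: {activation}")


x = torch.randn(4, 8)
print(moe_activation_sketch(x, "silu").shape)          # torch.Size([4, 4])
print(moe_activation_sketch(x, "relu2_no_mul").shape)  # torch.Size([4, 8])
```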
35 changes: 30 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -403,11 +403,15 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_up_dim,
hidden_size,
dtype=params_dtype,
),
@@ -417,9 +421,7 @@ def create_weights(
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype
),
torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
@@ -1033,6 +1035,7 @@ def __init__(
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
@@ -1185,6 +1188,7 @@ def __init__(
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
)
self.moe_config = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
@@ -1205,6 +1209,24 @@ def __init__(
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method

if not self.moe_config.is_act_and_mul:
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
)

if not isinstance(
quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now"
)
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
)
Comment on lines +1212 to +1228

Member: What are the blockers for supporting is_act_and_mul = False more generally?

Contributor (Author): Creating the relevant kernels :) We plan to follow up with that.

if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod

@@ -1438,7 +1460,10 @@ def _load_w13(
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
if self.moe_config.is_act_and_mul:
shard_size = expert_data.shape[shard_dim] // 2
else:
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
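A simplified sketch of the sharding change in `_load_w13` (the helper name and toy sizes are invented; the real loader also distinguishes w1/w3 shard ids, biases, and full loads):

```python
import torch


def tp_slice_for_w13(expert_data: torch.Tensor, loaded_weight: torch.Tensor,
                     tp_rank: int, is_act_and_mul: bool,
                     shard_dim: int = 0) -> torch.Tensor:
    """Illustrative only: pick one tensor-parallel slice of a w13 projection."""
    if is_act_and_mul:
        # Fused buffer stacks w1 and w3, so a single loaded projection
        # occupies half of expert_data along shard_dim.
        shard_size = expert_data.shape[shard_dim] // 2
    else:
        # Non-gated: the buffer holds one up projection, use the full dim.
        shard_size = expert_data.shape[shard_dim]
    # Take this rank's slice of the checkpoint tensor.
    return loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)


# Tiny non-gated example: per-rank buffer is (8, 4); the full checkpoint up
# projection is (16, 4) across tp_size=2, so rank 1 takes rows 8:16.
expert_buf = torch.empty(8, 4)
ckpt = torch.randn(16, 4)
print(tp_slice_for_w13(expert_buf, ckpt, tp_rank=1, is_act_and_mul=False).shape)
# torch.Size([8, 4])
```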
31 changes: 26 additions & 5 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -354,7 +354,11 @@ def __init__(

self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
if (
envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
and self.moe.is_act_and_mul
):
Comment on lines +357 to +361

Member: For NemotronH, self.flashinfer_moe_backend will end up being None. What implementation ends up getting used in this case?

Contributor (Author): Triton kernels. This is currently the only code path available with is_act_and_mul=False.

Member: I suspect this is going to be very complicated to add to all the quant and kernel backends.

self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
@@ -405,10 +409,15 @@ def create_weights(
)
weight_loader = extra_weight_attrs.get("weight_loader")

if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition

w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_up_dim,
hidden_size,
dtype=weight_dtype,
),
@@ -433,11 +442,16 @@ def create_weights(

if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively.
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
if self.moe.is_act_and_mul:
w13_weight_scale_shape = (num_experts, 2)
else:
w13_weight_scale_shape = (num_experts, 1)
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
w13_weight_scale_shape,
1.0,
dtype=torch.float32,
),
@@ -485,7 +499,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2:
if (
layer.w13_weight_scale.dim() == 2
and layer.w13_weight_scale.shape[1] == 2
):
assert self.moe.is_act_and_mul, (
"w13_weight_scale should have 2 elements per expert "
"only for gated MoE"
)
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values

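To illustrate the scale handling that the new shape check guards, a small sketch (function name is invented; the subsequent dequantize/requantize of each half with the shared scale is omitted):

```python
import torch


def combine_w13_scales(w13_weight_scale: torch.Tensor) -> torch.Tensor:
    """Sketch of the gated-MoE scale fold above: FP8 MoE kernels expect one
    per-expert scale for the fused w13 tensor, so the two per-tensor scales
    loaded for w1 and w3 are reduced to their max."""
    assert w13_weight_scale.dim() == 2 and w13_weight_scale.shape[1] == 2
    return w13_weight_scale.max(dim=1).values  # shape: (num_experts,)


# Non-gated checkpoints allocate the scale as (num_experts, 1), so the new
# shape check above simply skips this fold.
scales = torch.tensor([[0.5, 0.7], [1.2, 0.9]])
print(combine_w13_scales(scales))  # tensor([0.7000, 1.2000])
```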