From 14b210505a1f31e46e51087b23c912ff8527cd53 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:26:27 +0300 Subject: [PATCH 1/8] Add option for MoE layers in NemotronH, with non-gated MoE with squaredReLu activation - adapt the FusedMoE object to support is_act_and_mul=False Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .../layers/fused_moe/fused_moe.py | 5 +- vllm/model_executor/layers/fused_moe/layer.py | 36 +- vllm/model_executor/models/nemotron_h.py | 325 ++++++++++++++++-- vllm/transformers_utils/configs/nemotron_h.py | 21 +- 4 files changed, 348 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d4de3f640865..958fc38b2319 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1297,6 +1297,7 @@ def fused_experts( SILU_NO_MUL: str = activation_without_mul("silu") GELU_NO_MUL: str = activation_without_mul("gelu") +RELU2_NO_MUL: str = activation_without_mul("relu2") def _get_config_quant_dtype( @@ -1506,7 +1507,9 @@ def fused_experts_impl( intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) - + elif activation == RELU2_NO_MUL: + intermediate_cache2 = torch.square( + F.relu(intermediate_cache1.view(-1, N))) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}.") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index da513d75da4d..90ede7469a9b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -314,20 +314,22 @@ def select_gemm_impl( def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + if layer.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + w13_up_dim, + hidden_size, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) if self.moe.has_bias: - w13_bias = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - dtype=params_dtype), + w13_bias = torch.nn.Parameter(torch.zeros(num_experts, + w13_up_dim, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) @@ -835,6 +837,7 @@ def __init__( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + is_act_and_mul: bool = True, enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, @@ -951,6 +954,7 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation + self.is_act_and_mul = is_act_and_mul if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported 
for " @@ -987,6 +991,15 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if not self.is_act_and_mul: + if not isinstance(quant_method, UnquantizedFusedMoEMethod): + raise NotImplementedError( + "is_act_and_mul=False is only supported for unquantized " + "moe for now") + if not current_platform.is_cuda(): + raise NotImplementedError( + "is_act_and_mul=False is only supported for CUDA for now") + if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import ( Fp8MoEMethod) @@ -1192,7 +1205,10 @@ def _load_w13(self, # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 + if self.is_act_and_mul: + shard_size = expert_data.shape[shard_dim] // 2 + else: + shard_size = expert_data.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 1e1f0524bd06..c84766e14c32 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -17,7 +17,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only NemotronH model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from typing import Optional import torch @@ -27,13 +28,17 @@ from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.config.parallel import ParallelConfig +from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( @@ -59,29 +64,22 @@ from vllm.transformers_utils.configs import NemotronHConfig from vllm.utils import LayerBlockType +from .utils import is_pp_missing_parameter + class NemotronHMLP(nn.Module): def __init__( self, config: NemotronHConfig, - layer_idx: int, + intermediate_size: int, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, + reduce_results: bool = True, prefix: str = "", ) -> None: super().__init__() - hybrid_override_pattern = config.hybrid_override_pattern - mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1 - if isinstance(config.intermediate_size, list): - if len(config.intermediate_size) == 1: - intermediate_size = config.intermediate_size[0] - else: - intermediate_size = config.intermediate_size[mlp_index] - else: - intermediate_size = config.intermediate_size - self.up_proj = ColumnParallelLinear( input_size=config.hidden_size, output_size=intermediate_size, @@ -94,6 +92,7 @@ def __init__( output_size=config.hidden_size, bias=bias, quant_config=quant_config, + 
reduce_results=reduce_results, prefix=f"{prefix}.down_proj", ) self.act_fn = ReLUSquaredActivation() @@ -105,6 +104,111 @@ def forward(self, x: torch.Tensor): return x +class NemotronHMoE(nn.Module): + + def __init__( + self, + config: NemotronHConfig, + quant_config: Optional[QuantizationConfig] = None, + parallel_config: Optional[ParallelConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + # Load balancing settings. + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts # noqa: E501 + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + activation=activation_without_mul(config.mlp_hidden_act), + is_act_and_mul=False, # non-gated MoE + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + + if config.n_shared_experts is not None: + self.shared_experts = NemotronHMLP( + config=config, + intermediate_size=config.moe_shared_expert_intermediate_size, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + if hidden_states.dtype != torch.float16: + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + else: + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: + if hidden_states.dtype != torch.float16: + final_hidden_states = final_hidden_states + shared_output + else: + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + final_hidden_states = final_hidden_states + shared_output \ + * (1. 
/ self.routed_scaling_factor) + + if self.tp_size > 1: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(num_tokens, hidden_dim) + + class NemotronHMLPDecoderLayer(nn.Module): def __init__( @@ -114,20 +218,71 @@ def __init__( model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + parallel_config: Optional[ParallelConfig] = None, prefix: str = "", ) -> None: super().__init__() self.config = config + hybrid_override_pattern = config.hybrid_override_pattern + mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1 + if isinstance(config.intermediate_size, list): + if len(config.intermediate_size) == 1: + intermediate_size = config.intermediate_size[0] + else: + intermediate_size = config.intermediate_size[mlp_index] + else: + intermediate_size = config.intermediate_size + self.mixer = NemotronHMLP( config, + intermediate_size=intermediate_size, quant_config=quant_config, bias=config.mlp_bias, prefix=f"{prefix}.mixer", - layer_idx=layer_idx, ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states) + return hidden_states, residual + + +class NemotronHMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + parallel_config: Optional[ParallelConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.mixer = NemotronHMoE( + config, + quant_config=quant_config, + parallel_config=parallel_config, + prefix=f"{prefix}.mixer", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -154,6 +309,7 @@ def __init__( model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + parallel_config: Optional[ParallelConfig] = None, prefix: str = "", ) -> None: super().__init__() @@ -168,7 +324,7 @@ def __init__( n_groups=config.n_groups, num_heads=config.mamba_num_heads, head_dim=config.mamba_head_dim, - rms_norm_eps=config.rms_norm_eps, + rms_norm_eps=config.layer_norm_epsilon, activation=config.mamba_hidden_act, model_config=model_config, cache_config=cache_config, @@ -176,7 +332,7 @@ def __init__( prefix=f"{prefix}.mixer", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -279,6 +435,7 @@ def __init__( model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + parallel_config: Optional[ParallelConfig] = None, prefix: str = "", ) -> None: super().__init__() @@ -292,7 +449,7 @@ def __init__( prefix=f"{prefix}.mixer", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -315,6 
+472,7 @@ def forward( "M": NemotronHMambaDecoderLayer, "-": NemotronHMLPDecoderLayer, "*": NemotronHAttentionDecoderLayer, + "E": NemotronHMoEDecoderLayer, } @@ -328,6 +486,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config lora_config = vllm_config.lora_config self.config = config @@ -342,16 +501,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) + self.has_moe = "E" in config.hybrid_override_pattern + def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.hybrid_override_pattern[layer_idx]] return layer_class( - config, - layer_idx, - model_config, - cache_config, + config=config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, + parallel_config=parallel_config, prefix=prefix, ) @@ -362,7 +524,8 @@ def get_layer(prefix: str): self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size) - self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -434,6 +597,22 @@ def load_weights(self, weights: Iterable[tuple[str, ("qkv_proj", "v_proj", "v"), ] + if self.has_moe: + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's + # what the activation is applied to + # - FusedMoe.w3 (aka up_proj) should be ignored since we're + # using non-gated MoE + ckpt_gate_proj_name="up_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="", + num_experts=self.config.n_routed_experts, + num_redundant_experts=getattr(self, "num_redundant_experts", + 0)) + else: + expert_params_mapping = [] + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -459,10 +638,45 @@ def load_weights(self, weights: Iterable[tuple[str, # load other params else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
+ weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -575,6 +789,63 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intmd_tensors = (self.model.make_empty_intmd_tensors) + # Set MoE hyperparameters + if self.model.has_moe: + self.expert_weights = [] + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, NemotronHMoEDecoderLayer): + # Pick last one layer since the first ones + # may be dense layers. + example_moe = layer.mixer + self.moe_layers.append(layer.mixer.experts) + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts # noqa: E501 + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + self.model.num_redundant_experts = self.num_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + self.model.num_redundant_experts = self.num_redundant_experts + for layer in self.model.layers: + if isinstance(layer, NemotronHMoEDecoderLayer): + moe = layer.mixer + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index 581bed5716c1..55fba0325db0 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -185,6 +185,15 @@ def __init__( mamba_proj_bias=False, mamba_chunk_size=256, rescale_prenorm_residual=True, + n_routed_experts=8, + n_shared_experts=1, + moe_intermediate_size=7688, + moe_shared_expert_intermediate_size=7688, + num_experts_per_tok=2, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + norm_topk_prob=True, **kwargs, ): self.vocab_size = vocab_size @@ -241,6 +250,15 @@ def __init__( 
self.mamba_proj_bias = mamba_proj_bias self.chunk_size = mamba_chunk_size self.rescale_prenorm_residual = rescale_prenorm_residual + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.moe_intermediate_size = moe_intermediate_size + self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size # noqa: E501 + self.num_experts_per_tok = num_experts_per_tok + self.routed_scaling_factor = routed_scaling_factor + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob super().__init__( pad_token_id=pad_token_id, @@ -254,6 +272,7 @@ def __init__( def layers_block_type(self): return [ "mamba" if self.hybrid_override_pattern[i] == "M" else - "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + "attention" if self.hybrid_override_pattern[i] == "*" else + "mlp" if self.hybrid_override_pattern[i] == "-" else "moe" for i in range(self.num_hidden_layers) ] From e5ad365b3ae64b7054d7f6080299539285a2069f Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:21:38 +0300 Subject: [PATCH 2/8] Add support for non-gated moe in triton path for ModelOptFp8MoEMethod Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/layer.py | 12 ++++++++---- .../layers/quantization/modelopt.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 90ede7469a9b..c3ac4eb12140 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -34,6 +34,8 @@ RoutingSimulator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8MoEMethod) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -992,13 +994,15 @@ def __init__( self.quant_method = quant_method if not self.is_act_and_mul: - if not isinstance(quant_method, UnquantizedFusedMoEMethod): + if not isinstance( + quant_method, + (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)): raise NotImplementedError( - "is_act_and_mul=False is only supported for unquantized " - "moe for now") + "is_act_and_mul=False is supported only for unquantized " + "and ModelOpt FP8 moe for now") if not current_platform.is_cuda(): raise NotImplementedError( - "is_act_and_mul=False is only supported for CUDA for now") + "is_act_and_mul=False is supported only for CUDA for now") if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import ( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 7eac40825ac3..b38ccbc890c3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -296,7 +296,8 @@ def __init__( cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None - if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe() and \ + layer.is_act_and_mul: self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer 
{self.flashinfer_moe_backend.value} kernels" @@ -344,9 +345,14 @@ def create_weights( params_dtype) weight_loader = extra_weight_attrs.get("weight_loader") + if layer.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition + w13_weight = ModelWeightParameter( data=torch.empty(num_experts, - 2 * intermediate_size_per_partition, + w13_up_dim, hidden_size, dtype=weight_dtype), input_dim=2, @@ -370,9 +376,13 @@ def create_weights( # WEIGHT SCALES - Per-tensor scaling for ModelOpts # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. + if layer.is_act_and_mul: + w13_weight_scale_shape = (num_experts, 2) + else: + w13_weight_scale_shape = (num_experts, ) w13_weight_scale = PerTensorScaleParameter( data=torch.full( - (num_experts, 2), + w13_weight_scale_shape, 1.0, dtype=torch.float32, ), @@ -421,6 +431,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # We take the max of the w1 and w3 scales # then dequant and requant each expert. if layer.w13_weight_scale.dim() == 2: + assert layer.is_act_and_mul, ( + "w13_weight_scale should be 2D only for gated MoE") # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values @@ -1416,7 +1428,6 @@ def apply( if (self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM): import flashinfer - from vllm.model_executor.models.llama4 import Llama4MoE assert self.fused_experts is None From 6b77e40e4f5b84e023ea3689736b1e4e06e29726 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Sun, 28 Sep 2025 17:00:27 +0300 Subject: [PATCH 3/8] (1) fix weight_scale shape (2) avoid circular import Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- vllm/model_executor/layers/quantization/modelopt.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 03140ffc0a37..540367c473b2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -37,8 +37,6 @@ RoutingSimulator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptFp8MoEMethod) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -1114,6 +1112,8 @@ def __init__( self.quant_method = quant_method if not self.is_act_and_mul: + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8MoEMethod) # Avoid circular import if not isinstance( quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)): diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 7c9ab579f7dc..a7e8dd2d610f 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -383,12 +383,13 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALES - Per-tensor scaling for ModelOpts - # Allocate 2 scales for w1 and w3 respectively. + # For gated MoE, allocate 2 scales for w1 and w3 respectively. 
# They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. if layer.is_act_and_mul: w13_weight_scale_shape = (num_experts, 2) else: - w13_weight_scale_shape = (num_experts, ) + w13_weight_scale_shape = (num_experts, 1) w13_weight_scale = PerTensorScaleParameter( data=torch.full( w13_weight_scale_shape, @@ -439,9 +440,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. - if layer.w13_weight_scale.dim() == 2: + if layer.w13_weight_scale.dim() == 2 and \ + layer.w13_weight_scale.shape[1] == 2: assert layer.is_act_and_mul, ( - "w13_weight_scale should be 2D only for gated MoE") + "w13_weight_scale should have 2 elements per expert " + "only for gated MoE") # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values From 7cb22e872e650218cba43586d8c099877127b172 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Sun, 28 Sep 2025 17:33:53 +0300 Subject: [PATCH 4/8] Add is_act_and_mul to FusedMoEConfig instead of keeping it directly as an attribute in FusedMoE Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/config.py | 2 ++ vllm/model_executor/layers/fused_moe/layer.py | 8 ++++---- vllm/model_executor/layers/quantization/modelopt.py | 8 ++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 34bfe1c16aac..765100568d21 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -747,6 +747,8 @@ class FusedMoEConfig: has_bias: bool = False + is_act_and_mul: bool = True + def __post_init__(self): if self.dp_size > 1: logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d", diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 540367c473b2..ecd01b11ac2c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -362,7 +362,7 @@ def select_gemm_impl( def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - if layer.is_act_and_mul: + if self.moe.is_act_and_mul: w13_up_dim = 2 * intermediate_size_per_partition else: w13_up_dim = intermediate_size_per_partition @@ -1081,7 +1081,6 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation - self.is_act_and_mul = is_act_and_mul if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -1096,6 +1095,7 @@ def __init__( in_dtype=moe_in_dtype, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, has_bias=has_bias, + is_act_and_mul=is_act_and_mul, ) self.moe_config = moe self.moe_quant_config: Optional[FusedMoEQuantConfig] = None @@ -1111,7 +1111,7 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method - if not self.is_act_and_mul: + if not self.moe_config.is_act_and_mul: from vllm.model_executor.layers.quantization.modelopt import ( ModelOptFp8MoEMethod) # 
Avoid circular import if not isinstance( @@ -1331,7 +1331,7 @@ def _load_w13(self, # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - if self.is_act_and_mul: + if self.moe_config.is_act_and_mul: shard_size = expert_data.shape[shard_dim] // 2 else: shard_size = expert_data.shape[shard_dim] diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a7e8dd2d610f..14bc455e924a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -306,7 +306,7 @@ def __init__( self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe() and \ - layer.is_act_and_mul: + self.moe.is_act_and_mul: self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" @@ -354,7 +354,7 @@ def create_weights( params_dtype) weight_loader = extra_weight_attrs.get("weight_loader") - if layer.is_act_and_mul: + if self.moe.is_act_and_mul: w13_up_dim = 2 * intermediate_size_per_partition else: w13_up_dim = intermediate_size_per_partition @@ -386,7 +386,7 @@ def create_weights( # For gated MoE, allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. # For non-gated MoE, allocate 1 scale for w13. - if layer.is_act_and_mul: + if self.moe.is_act_and_mul: w13_weight_scale_shape = (num_experts, 2) else: w13_weight_scale_shape = (num_experts, 1) @@ -442,7 +442,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # then dequant and requant each expert. if layer.w13_weight_scale.dim() == 2 and \ layer.w13_weight_scale.shape[1] == 2: - assert layer.is_act_and_mul, ( + assert self.moe.is_act_and_mul, ( "w13_weight_scale should have 2 elements per expert " "only for gated MoE") From d9258afa45a393b378d47f513891d7c004c39d1f Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 8 Oct 2025 19:56:02 +0300 Subject: [PATCH 5/8] router logits and bias in FP32 Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/models/nemotron_h.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 3a265678d2cc..2ccb840704e8 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -135,12 +135,13 @@ def __init__( config.hidden_size, config.n_routed_experts, bias=False, + params_dtype=torch.float32, quant_config=None, prefix=f"{prefix}.gate", ) self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts) + torch.empty(config.n_routed_experts, dtype=torch.float32) ) # Load balancing settings. 
self.enable_eplb = parallel_config.enable_eplb @@ -190,7 +191,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) if hidden_states.dtype != torch.float16: final_hidden_states = ( From 404c4a41e01e57ca31fb826a2faa9650ff7dc605 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 8 Oct 2025 20:51:48 +0300 Subject: [PATCH 6/8] use SharedFusedMoE to overlap shared expert computation with routed experts Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/models/nemotron_h.py | 113 ++++++++++++++--------- 1 file changed, 70 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 2ccb840704e8..2e136d808688 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -48,6 +48,7 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, @@ -156,63 +157,89 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func="sigmoid", - e_score_correction_bias=self.gate.e_score_correction_bias, - activation=activation_without_mul(config.mlp_hidden_act), - is_act_and_mul=False, # non-gated MoE - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - ) + if config.n_shared_experts is None or config.n_shared_experts == 0: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + activation=activation_without_mul(config.mlp_hidden_act), + is_act_and_mul=False, # non-gated MoE + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + self.shared_experts = None + else: + intermediate_size = ( + config.moe_shared_expert_intermediate_size * config.n_shared_experts + ) - if config.n_shared_experts is not None: self.shared_experts = NemotronHMLP( config=config, - intermediate_size=config.moe_shared_expert_intermediate_size, + intermediate_size=intermediate_size, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + 
num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + activation=activation_without_mul(config.mlp_hidden_act), + is_act_and_mul=False, # non-gated MoE + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) - if hidden_states.dtype != torch.float16: - final_hidden_states = ( - self.experts(hidden_states=hidden_states, router_logits=router_logits) - * self.routed_scaling_factor - ) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output * ( - 1.0 / self.routed_scaling_factor - ) + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. 
+ if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( From 7fff9a83945e9f184d5f1ce085d3101620381af2 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:12:29 +0300 Subject: [PATCH 7/8] fix ruff according to CI Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/layers/quantization/modelopt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 254ff2b56d95..c70421718cad 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1345,13 +1345,12 @@ def prepare_static_weights_for_trtllm_fp4_moe( intermediate_size, num_experts, ): + from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( _maybe_get_cached_w2_permute_indices, _maybe_get_cached_w3_w1_permute_indices, ) - from flashinfer import nvfp4_block_scale_interleave - """Prepare quantized weights for kernel (done offline with weights).""" epilogue_tile_m = 128 # FIXME: this depends on the kernel internals @@ -1637,6 +1636,7 @@ def apply( and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): import flashinfer + from vllm.model_executor.models.llama4 import Llama4MoE assert self.fused_experts is None From 0090a48d20fb81acf84aa0def0945663fd2d068f Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 16 Oct 2025 00:03:20 +0300 Subject: [PATCH 8/8] fix import Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/models/nemotron_h.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 61e4147f487c..224c9065db6e 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -31,7 +31,7 @@ from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -47,7 +47,6 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead,
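---

Reviewer note (not part of the patches above): a minimal standalone PyTorch sketch of the two expert-MLP shapes FusedMoE handles after this series. All tensor names and sizes here are made up for illustration; the point is the w13 output dimension and the activation, matching the w13_up_dim logic in layer.py/modelopt.py and the RELU2_NO_MUL branch in fused_moe.py.

    import torch
    import torch.nn.functional as F

    hidden, inter, tokens = 64, 128, 4
    x = torch.randn(tokens, hidden)
    w2 = torch.randn(hidden, inter)  # down_proj, shared by both variants

    # Gated expert (is_act_and_mul=True, the existing path): w13 stacks
    # gate_proj and up_proj, so its output dim is 2 * inter and the
    # activation multiplies the two halves.
    w13_gated = torch.randn(2 * inter, hidden)
    gate, up = (x @ w13_gated.t()).chunk(2, dim=-1)
    y_gated = (F.silu(gate) * up) @ w2.t()          # (tokens, hidden)

    # Non-gated expert (is_act_and_mul=False, added here for NemotronH MoE):
    # w13 is only up_proj (output dim inter) and the activation is squared
    # ReLU, i.e. the RELU2_NO_MUL branch.
    w13_plain = torch.randn(inter, hidden)
    y_plain = torch.square(F.relu(x @ w13_plain.t())) @ w2.t()  # (tokens, hidden)

This is why create_weights sizes w13 as intermediate_size (not 2x) when is_act_and_mul=False, why _load_w13 skips the //2 shard split, and why only one w13 per-tensor scale per expert is needed in the ModelOpt FP8 path.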