diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 6d3477cd1991..200212dfb42b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -823,6 +823,8 @@ class FusedMoEConfig: has_bias: bool = False + is_act_and_mul: bool = True + def __post_init__(self): if self.dp_size > 1: logger.debug_once( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 031381332cc9..89e92edc8d2b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1647,6 +1647,7 @@ def fused_experts( SILU_NO_MUL: str = activation_without_mul("silu") GELU_NO_MUL: str = activation_without_mul("gelu") +RELU2_NO_MUL: str = activation_without_mul("relu2") def _get_config_quant_dtype( @@ -1914,7 +1915,8 @@ def fused_experts_impl( intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) - + elif activation == RELU2_NO_MUL: + intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N))) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}.") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6a61df739d14..2c86b8ed32f3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -411,11 +411,15 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + if self.moe.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_up_dim, hidden_size, dtype=params_dtype, ), @@ -425,9 +429,7 @@ def create_weights( set_weight_attrs(w13_weight, extra_weight_attrs) if self.moe.has_bias: w13_bias = torch.nn.Parameter( - torch.zeros( - num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype - ), + torch.zeros(num_experts, w13_up_dim, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) @@ -1073,6 +1075,7 @@ def __init__( e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + is_act_and_mul: bool = True, enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, @@ -1263,6 +1266,7 @@ def __init__( in_dtype=moe_in_dtype, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, has_bias=has_bias, + is_act_and_mul=is_act_and_mul, ) self.moe_config = moe self.moe_quant_config: FusedMoEQuantConfig | None = None @@ -1283,6 +1287,24 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if not self.moe_config.is_act_and_mul: + # Avoid circular import + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8MoEMethod, + ) + + if not isinstance( + quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod) + ): + raise NotImplementedError( + "is_act_and_mul=False is supported only for unquantized " + "and ModelOpt FP8 moe for now" + ) + if not current_platform.is_cuda(): + raise NotImplementedError( + "is_act_and_mul=False is supported only for CUDA for now" + ) + if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import 
Fp8MoEMethod @@ -1531,7 +1553,10 @@ def _load_w13( ): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 + if self.moe_config.is_act_and_mul: + shard_size = expert_data.shape[shard_dim] // 2 + else: + shard_size = expert_data.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3eeb42d22ae0..0eeeaa3ce457 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -354,7 +354,11 @@ def __init__( self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: FlashinferMoeBackend | None = None - if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + if ( + envs.VLLM_USE_FLASHINFER_MOE_FP8 + and has_flashinfer_moe() + and self.moe.is_act_and_mul + ): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" @@ -405,10 +409,15 @@ def create_weights( ) weight_loader = extra_weight_attrs.get("weight_loader") + if self.moe.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition + w13_weight = ModelWeightParameter( data=torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_up_dim, hidden_size, dtype=weight_dtype, ), @@ -433,11 +442,16 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALES - Per-tensor scaling for ModelOpts - # Allocate 2 scales for w1 and w3 respectively. + # For gated MoE, allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. + if self.moe.is_act_and_mul: + w13_weight_scale_shape = (num_experts, 2) + else: + w13_weight_scale_shape = (num_experts, 1) w13_weight_scale = PerTensorScaleParameter( data=torch.full( - (num_experts, 2), + w13_weight_scale_shape, 1.0, dtype=torch.float32, ), @@ -485,7 +499,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. 
- if layer.w13_weight_scale.dim() == 2: + if ( + layer.w13_weight_scale.dim() == 2 + and layer.w13_weight_scale.shape[1] == 2 + ): + assert self.moe.is_act_and_mul, ( + "w13_weight_scale should have 2 elements per expert " + "only for gated MoE" + ) # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 6e046c16b7ae..1bc5f5ae5419 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -673,7 +673,9 @@ def update_physical_experts_metadata( def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: - return isinstance(model, MixtureOfExperts) + return ( + isinstance(model, MixtureOfExperts) and getattr(model, "num_moe_layers", 0) > 0 + ) @runtime_checkable diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index a591f0b01c4e..f31579e5cfa8 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -18,7 +18,8 @@ # limitations under the License. """Inference-only NemotronH model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable import torch from torch import nn @@ -26,13 +27,18 @@ from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.config.parallel import ParallelConfig +from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size +from vllm.distributed.communication_op import tensor_model_parallel_all_gather from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE +from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -54,6 +60,7 @@ from vllm.model_executor.models.interfaces import ( HasInnerState, IsHybrid, + MixtureOfExperts, SupportsLoRA, SupportsPP, SupportsQuant, @@ -61,9 +68,11 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, + sequence_parallel_chunk, ) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig @@ -73,28 +82,21 @@ class NemotronHMLP(nn.Module): def __init__( self, config: NemotronHConfig, - layer_idx: int, + intermediate_size: int, quant_config: QuantizationConfig | None = None, bias: bool = False, + reduce_results: bool = True, + is_sequence_parallel: bool = False, prefix: str = "", ) -> None: super().__init__() - hybrid_override_pattern = config.hybrid_override_pattern - mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 - if isinstance(config.intermediate_size, list): - if len(config.intermediate_size) == 1: - intermediate_size = config.intermediate_size[0] - else: - intermediate_size = config.intermediate_size[mlp_index] - else: - intermediate_size = config.intermediate_size - 
self.up_proj = ColumnParallelLinear( input_size=config.hidden_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.up_proj", ) self.down_proj = RowParallelLinear( @@ -102,6 +104,8 @@ def __init__( output_size=config.hidden_size, bias=bias, quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.down_proj", ) self.act_fn = ReLUSquaredActivation() @@ -113,6 +117,130 @@ def forward(self, x: torch.Tensor): return x +class NemotronHMoE(nn.Module): + def __init__( + self, + config: NemotronHConfig, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + params_dtype=torch.float32, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) + # Load balancing settings. + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts # noqa: E501 + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + if config.n_shared_experts is None or config.n_shared_experts == 0: + self.shared_experts = None + else: + intermediate_size = ( + config.moe_shared_expert_intermediate_size * config.n_shared_experts + ) + + self.shared_experts = NemotronHMLP( + config=config, + intermediate_size=intermediate_size, + quant_config=quant_config, + reduce_results=False, + is_sequence_parallel=self.is_sequence_parallel, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + activation=activation_without_mul(config.mlp_hidden_act), + is_act_and_mul=False, # non-gated MoE + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + # router_logits: (num_tokens, 
n_experts) + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_dim) + + class NemotronHMLPDecoderLayer(nn.Module): def __init__( self, @@ -121,20 +249,70 @@ def __init__( model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config + hybrid_override_pattern = config.hybrid_override_pattern + mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 + if isinstance(config.intermediate_size, list): + if len(config.intermediate_size) == 1: + intermediate_size = config.intermediate_size[0] + else: + intermediate_size = config.intermediate_size[mlp_index] + else: + intermediate_size = config.intermediate_size + self.mixer = NemotronHMLP( config, + intermediate_size=intermediate_size, quant_config=quant_config, bias=config.mlp_bias, prefix=f"{prefix}.mixer", - layer_idx=layer_idx, ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states) + return hidden_states, residual + + +class NemotronHMoEDecoderLayer(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.mixer = NemotronHMoE( + config, + quant_config=quant_config, + parallel_config=parallel_config, + prefix=f"{prefix}.mixer", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -160,6 +338,7 @@ def __init__( model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -174,7 +353,7 @@ def __init__( n_groups=config.n_groups, num_heads=config.mamba_num_heads, head_dim=config.mamba_head_dim, - rms_norm_eps=config.rms_norm_eps, + rms_norm_eps=config.layer_norm_epsilon, activation=config.mamba_hidden_act, model_config=model_config, 
cache_config=cache_config, @@ -182,7 +361,7 @@ def __init__( prefix=f"{prefix}.mixer", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -281,6 +460,7 @@ def __init__( model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -294,7 +474,7 @@ def __init__( prefix=f"{prefix}.mixer", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -317,6 +497,7 @@ def forward( "M": NemotronHMambaDecoderLayer, "-": NemotronHMLPDecoderLayer, "*": NemotronHAttentionDecoderLayer, + "E": NemotronHMoEDecoderLayer, } @@ -329,6 +510,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config lora_config = vllm_config.lora_config self.config = config @@ -346,17 +528,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) + self.has_moe = "E" in config.hybrid_override_pattern + def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.hybrid_override_pattern[layer_idx] ] return layer_class( - config, - layer_idx, - model_config, - cache_config, + config=config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, + parallel_config=parallel_config, prefix=prefix, ) @@ -367,7 +552,7 @@ def get_layer(prefix: str): ["hidden_states", "residual"], config.hidden_size ) - self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -413,6 +598,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("qkv_proj", "v_proj", "v"), ] + if self.has_moe: + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's + # what the activation is applied to + # - FusedMoe.w3 (aka up_proj) should be ignored since we're + # using non-gated MoE + ckpt_gate_proj_name="up_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="", + num_experts=self.config.n_routed_experts, + num_redundant_experts=getattr(self, "num_redundant_experts", 0), + ) + else: + expert_params_mapping = [] + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -438,16 +639,62 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # load other params else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify 
`name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class NemotronHForCausalLM( - nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + IsHybrid, + SupportsQuant, + MixtureOfExperts, ): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"backbone": "model"}, @@ -545,6 +792,61 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + # Set MoE hyperparameters + if self.model.has_moe: + self.expert_weights = [] + self.num_expert_groups = config.n_group + + self.moe_layers: list[SharedFusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, NemotronHMoEDecoderLayer): + # Pick last one layer since the first ones + # may be dense layers. + example_moe = layer.mixer + self.moe_layers.append(layer.mixer.experts) + + self.num_moe_layers = len(self.moe_layers) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts # noqa: E501 + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights.
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer, NemotronHMoEDecoderLayer): + moe = layer.mixer + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index c8b6784d6a8e..68c40002098c 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -185,6 +185,15 @@ def __init__( mamba_proj_bias=False, mamba_chunk_size=256, rescale_prenorm_residual=True, + n_routed_experts=8, + n_shared_experts=1, + moe_intermediate_size=7688, + moe_shared_expert_intermediate_size=7688, + num_experts_per_tok=2, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + norm_topk_prob=True, **kwargs, ): self.vocab_size = vocab_size @@ -241,6 +250,15 @@ def __init__( self.mamba_proj_bias = mamba_proj_bias self.chunk_size = mamba_chunk_size self.rescale_prenorm_residual = rescale_prenorm_residual + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.moe_intermediate_size = moe_intermediate_size + self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size # noqa: E501 + self.num_experts_per_tok = num_experts_per_tok + self.routed_scaling_factor = routed_scaling_factor + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob super().__init__( pad_token_id=pad_token_id, @@ -258,5 +276,7 @@ def layers_block_type(self): else "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + if self.hybrid_override_pattern[i] == "-" + else "moe" for i in range(self.num_hidden_layers) ]
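For reference, a minimal sketch (plain PyTorch, made-up shapes, a single expert, not the fused Triton/CUTLASS path) contrasting the default gated "act-and-mul" expert MLP with the non-gated path that is_act_and_mul=False selects, where the relu2 activation from RELU2_NO_MUL is applied directly to the up projection as in fused_experts_impl above:

    import torch
    import torch.nn.functional as F

    # hypothetical sizes: experts, hidden size, intermediate size, tokens
    num_experts, hidden, inter, tokens = 4, 16, 32, 8
    x = torch.randn(tokens, hidden)
    w2 = torch.randn(num_experts, hidden, inter)  # down projection, same in both variants

    # Gated MoE (is_act_and_mul=True): w13 packs w1 (gate) and w3 (up) -> [E, 2*inter, hidden]
    w13_gated = torch.randn(num_experts, 2 * inter, hidden)
    gate, up = (x @ w13_gated[0].T).chunk(2, dim=-1)
    gated_out = (F.silu(gate) * up) @ w2[0].T       # act(w1 x) * (w3 x), then down_proj

    # Non-gated MoE (is_act_and_mul=False): w13 is only the up projection -> [E, inter, hidden];
    # the activation is applied directly, here relu^2 to match RELU2_NO_MUL.
    w13 = torch.randn(num_experts, inter, hidden)
    nongated_out = torch.square(F.relu(x @ w13[0].T)) @ w2[0].T

This also illustrates why w13_weight_scale needs only one per-expert scale in the non-gated case: there is no separate w1/w3 pair to combine after weight loading.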