
Conversation

@tomeras91 (Contributor) commented Sep 29, 2025

Purpose

Add support for an MoE module in the NemotronH architecture.
This MoE module is fairly unusual (to the best of my knowledge, comparable only to nomic-ai/nomic-embed-text-v2-moe), as it uses a non-gated Squared ReLU activation function (see the sketch after the list below).

In this PR:

  • Add a NemotronHMoE module to the NemotronH modeling file
  • Add the option to use a non-gated MoE through the FusedMoE class (in addition to calling the fused_moe function directly)
  • Add support for the Squared ReLU activation function in the MoE triton path
  • Add support for Squared ReLU non-gated FP8 MoE in the ModelOptFp8MoEMethod quant_method, currently only in the triton path
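
For readers less familiar with the distinction, below is a minimal PyTorch sketch (not the vLLM kernel code) contrasting the usual gated "act-and-mul" expert MLP with the non-gated Squared ReLU MLP this PR adds. Weight names and shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def gated_silu_expert(x, w1, w3, w2):
    # Standard "act-and-mul" expert MLP: act(x @ w1^T) * (x @ w3^T),
    # followed by the down-projection w2.
    return (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

def squared_relu_expert(x, w1, w2):
    # Non-gated variant used here: a single up-projection followed by
    # ReLU(.)^2, then the down-projection. There is no separate gate matrix.
    return F.relu(x @ w1.T).square() @ w2.T
```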

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for a non-gated Squared ReLU MoE module in the NemotronH architecture, which is a valuable enhancement. The changes are mostly well-implemented across the fused MoE layers and model definition. However, I've identified a critical bug in the forward pass of the new NemotronHMoE module related to incorrect floating-point computation and a potential UnboundLocalError. I've provided a detailed comment with a suggested fix for this issue. Addressing this is crucial for the correctness of the model's output.

@tomeras91 (Contributor, Author) commented

To the reviewer(s)

NemotronHForCausalLM now optionally has an MoE block. Should it implement the MixtureOfExperts interface? Do you have any guidance?

Member commented

We might need to do something similar to PR #25311 (comment), where is_mixture_of_experts depends on an attribute of the model. I don't know all the cases where this is used, though.

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tomeras91.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 14, 2025
@mergify mergify bot removed the needs-rebase label Oct 15, 2025
Signed-off-by: Tomer Asida <[email protected]>
Comment on lines +1212 to +1228
if not self.moe_config.is_act_and_mul:
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
)

if not isinstance(
quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now"
)
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
)
Member commented

What are the blockers for supporting is_act_and_mul = False more generally?

@tomeras91 (Contributor, Author) commented

Creating the relevant kernels :) We plan to follow up with that

Comment on lines +357 to +361
if (
envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
and self.moe.is_act_and_mul
):
Member commented

For NemotronH, self.flashinfer_moe_backend will end up being None. What implementation ends up getting used in this case?

@tomeras91 (Contributor, Author) commented

The triton kernels. This is currently the only code path available with is_act_and_mul=False.

Member commented

I suspect this is going to be very complicated to add to all the quant and kernel backends.
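
For context on what the triton-only, is_act_and_mul=False path discussed above has to compute, here is a plain-PyTorch reference sketch. It assumes top-k routing with softmax renormalization over the selected experts, and it ignores EP/TP sharding and FP8 quantization; all names are illustrative, the routing details are assumptions, and the real kernels fuse this work. It is written for clarity, not efficiency.

```python
import torch
import torch.nn.functional as F

def reference_non_gated_moe(x, router_w, w1, w2, top_k):
    # x: [T, H], router_w: [E, H], w1: [E, I, H], w2: [E, H, I]
    logits = x @ router_w.T                                     # [T, E]
    topk_logits, expert_ids = torch.topk(logits, top_k, dim=-1)
    weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
    out = torch.zeros_like(x)
    for k in range(top_k):
        e = expert_ids[:, k]                                    # chosen expert per token
        h = torch.einsum("th,tih->ti", x, w1[e])                # up-projection
        h = F.relu(h).square()                                  # non-gated Squared ReLU
        y = torch.einsum("ti,thi->th", h, w2[e])                # down-projection
        out += weights[:, k].unsqueeze(-1).to(x.dtype) * y
    return out
```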

num_redundant_experts=self.n_redundant_experts,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@tlrmchlsmth (Member) commented Oct 15, 2025

For DP+TP cases, we should use the sequence-parallel trick like in #24982 to avoid duplicate work in the expert layers.
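
For readers following along, here is a conceptual torch.distributed sketch of the sequence-parallel trick referred to above (not the vLLM implementation and not the code from #24982): with DP+TP, every TP rank holds the same tokens, so each rank can run the experts on only its 1/tp_size slice of the sequence and all-gather the slices afterwards. The experts callable, the padding scheme, and the process-group handling are assumptions for illustration.

```python
import torch
import torch.distributed as dist

def moe_forward_sequence_parallel(hidden_states: torch.Tensor, experts, tp_group) -> torch.Tensor:
    """hidden_states: [num_tokens, hidden_size], replicated across the TP ranks."""
    tp_size = dist.get_world_size(group=tp_group)
    tp_rank = dist.get_rank(group=tp_group)

    # Pad so the token count divides evenly across ranks, then keep this rank's slice.
    num_tokens, hidden_size = hidden_states.shape
    padded = (num_tokens + tp_size - 1) // tp_size * tp_size
    if padded != num_tokens:
        pad = hidden_states.new_zeros(padded - num_tokens, hidden_size)
        hidden_states = torch.cat([hidden_states, pad], dim=0)
    local = hidden_states.chunk(tp_size, dim=0)[tp_rank]

    # Expert computation touches only the local slice -> ~1/tp_size of the work per rank.
    local_out = experts(local)

    # Reassemble the full (padded) sequence on every rank, then drop the padding.
    out = torch.empty_like(hidden_states)
    dist.all_gather_into_tensor(out, local_out.contiguous(), group=tp_group)
    return out[:num_tokens]
```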
