[Model] Add MoE support for NemotronH #25863
base: main
Changes from all commits
14b2105
e5ad365
1142be2
6b77e40
7cb22e8
76a12cf
d9258af
404c4a4
7fff9a8
f93c300
0090a48
@@ -354,7 +354,11 @@ def __init__(
         self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+        if (
+            envs.VLLM_USE_FLASHINFER_MOE_FP8
+            and has_flashinfer_moe()
+            and self.moe.is_act_and_mul
+        ):
Comment on lines +357 to +361

For NemotronH,

triton kernels. This is currently the only code path available with

I suspect this is going to be very complicated to add to all the quant and kernel backends

Agreed. We can follow-up on this discussion internally
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
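For readers following the gating change above: a gated ("act-and-mul") expert applies an activation to one up-projection and multiplies elementwise by a second one, while a non-gated expert such as NemotronH's uses a single up-projection. The sketch below is illustrative only; the function name, shapes, and activation choices are assumptions, not vLLM's kernel code.

import torch
import torch.nn.functional as F

def expert_mlp(x, w1, w3, w2, is_act_and_mul: bool):
    # Illustrative expert MLP, not the fused MoE kernel vLLM actually dispatches to.
    if is_act_and_mul:
        # Gated MoE: act(x @ w1^T) * (x @ w3^T), then the down-projection.
        h = F.silu(x @ w1.t()) * (x @ w3.t())
    else:
        # Non-gated MoE (NemotronH-style): a single up-projection and activation.
        # ReLU here is a placeholder, not necessarily NemotronH's actual activation.
        h = F.relu(x @ w1.t())
    return h @ w2.t()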
@@ -405,10 +409,15 @@ def create_weights(
         )
         weight_loader = extra_weight_attrs.get("weight_loader")

+        if self.moe.is_act_and_mul:
+            w13_up_dim = 2 * intermediate_size_per_partition
+        else:
+            w13_up_dim = intermediate_size_per_partition
+
         w13_weight = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size,
                 dtype=weight_dtype,
             ),
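To make the shape change concrete, here is a small illustration. The sizes are made up; only the doubling of the up-projection dimension for gated MoE comes from the diff above.

import torch

num_experts = 8
intermediate_size_per_partition = 1024
hidden_size = 2048

# Gated MoE fuses w1 and w3 into a single w13 tensor, so the up dimension doubles.
w13_gated = torch.empty(num_experts, 2 * intermediate_size_per_partition, hidden_size)
# Non-gated MoE (is_act_and_mul=False) only has w1, so the up dimension stays as-is.
w13_plain = torch.empty(num_experts, intermediate_size_per_partition, hidden_size)

print(w13_gated.shape)  # torch.Size([8, 2048, 2048])
print(w13_plain.shape)  # torch.Size([8, 1024, 2048])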
@@ -433,11 +442,16 @@ def create_weights(
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALES - Per-tensor scaling for ModelOpts
-            # Allocate 2 scales for w1 and w3 respectively.
+            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.
+            # For non-gated MoE, allocate 1 scale for w13.
+            if self.moe.is_act_and_mul:
+                w13_weight_scale_shape = (num_experts, 2)
+            else:
+                w13_weight_scale_shape = (num_experts, 1)
             w13_weight_scale = PerTensorScaleParameter(
                 data=torch.full(
-                    (num_experts, 2),
+                    w13_weight_scale_shape,
                     1.0,
                     dtype=torch.float32,
                 ),
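The per-tensor weight scales follow the same split. A minimal sketch of the allocation, assuming made-up sizes and ignoring the ModelOpt checkpoint details:

import torch

num_experts = 8

def make_w13_scale(is_act_and_mul: bool) -> torch.Tensor:
    # Gated MoE: one scale each for the w1 and w3 halves, fused after loading.
    # Non-gated MoE: a single scale per expert covering the whole w13 tensor.
    shape = (num_experts, 2) if is_act_and_mul else (num_experts, 1)
    return torch.full(shape, 1.0, dtype=torch.float32)

print(make_w13_scale(True).shape)   # torch.Size([8, 2])
print(make_w13_scale(False).shape)  # torch.Size([8, 1])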
@@ -485,7 +499,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max of the w1 and w3 scales
             # then dequant and requant each expert.
-            if layer.w13_weight_scale.dim() == 2:
+            if (
+                layer.w13_weight_scale.dim() == 2
+                and layer.w13_weight_scale.shape[1] == 2
+            ):
+                assert self.moe.is_act_and_mul, (
+                    "w13_weight_scale should have 2 elements per expert "
+                    "only for gated MoE"
+                )
                 # Get the maximum scale across w1 and w3 for each expert
                 max_w13_scales = layer.w13_weight_scale.max(dim=1).values
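For context on what the guarded branch does in the gated case, here is a simplified sketch of dequantizing each w13 half with its own scale and requantizing with the shared max scale. It uses int8 as a stand-in for fp8 and an invented helper name; the real code goes through vLLM's quantization utilities rather than this loop.

import torch

def fuse_w13_scales(w13_q: torch.Tensor, w13_scale: torch.Tensor):
    # w13_q: (num_experts, 2 * inter, hidden) quantized weights (int8 stand-in for fp8)
    # w13_scale: (num_experts, 2) per-tensor scales for the w1 and w3 halves
    num_experts, two_inter, _ = w13_q.shape
    inter = two_inter // 2
    max_scales = w13_scale.max(dim=1).values  # one shared scale per expert
    out = torch.empty_like(w13_q)
    for e in range(num_experts):
        for half in range(2):
            sl = slice(half * inter, (half + 1) * inter)
            # Dequantize with the half's own scale, requantize with the shared scale.
            dequant = w13_q[e, sl].float() * w13_scale[e, half]
            out[e, sl] = torch.clamp(
                torch.round(dequant / max_scales[e]), -128, 127
            ).to(w13_q.dtype)
    return out, max_scales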
What are the blockers for supporting is_act_and_mul = False more generally?

Creating the relevant kernels :) We plan to follow up with that