
Commit 870e6a2

Enable Yarn RoPE in minitron pruning for gpt-oss
Signed-off-by: Keval Morabia <[email protected]>
1 parent 41c66bb commit 870e6a2

8 files changed, +107 −55 lines changed

CHANGELOG.rst
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ Model Optimizer Changelog (Linux)
 
 **New Features**
 
-- Add MoE (e.g. Qwen3-30B-A3B) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
+- Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 
 0.39 (2025-11-14)

examples/megatron-lm/README.md
Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
 | `Qwen/Qwen3-{0.6B, 8B}` || **Online** | |||
 | `deepseek-ai/DeepSeek-R1` || **Online** | | | |
 | `meta-llama/Llama-{3.1-8B, 3.1-405B, 3.2-1B}-Instruct` || **Online** | |||
+| `openai/gpt-oss-{20b, 120b}` || **Online** | |||
 
 ## Getting Started in a Local Environment
 
examples/pruning/README.md
Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@ Pruning can involve removal (prune) of Linear and Conv layers, and Transformer a
 
 This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model:
 
-1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT, Mamba and Hybrid Transformer Mamba models in NVIDIA NeMo or Megatron-LM framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads and GQA query groups; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model.
+1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM or NeMo framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads and GQA query groups; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model.
 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.
 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints.
 
@@ -89,7 +89,7 @@ If your model parameters are already sorted, you can skip the sorting step by se
 
 | **Algorithm** | **Model** | **Pruning Constraints** |
 | :---: | :---: | :---: |
-| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
+| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid LLM Models<sup>1</sup> | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values |
 | FastNAS | Computer Vision models | flops, parameters |
 | GradNAS | HuggingFace BERT, GPT-J | flops, parameters |
 
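For context on the "Export config" constraint in the Minitron row, the export config is a plain dict of target width/depth values. The sketch below shows roughly how it is passed to pruning; the `mtp.prune` argument names (`constraints`, `dummy_input`, `config`), the return value, and the `forward_loop` calibration hook are assumptions based on the rest of this README rather than part of this diff, and the target sizes are made-up examples:

```python
import modelopt.torch.prune as mtp

# Hypothetical pruning targets; any subset of the width/depth keys from the table may be given.
export_config = {
    "hidden_size": 2048,
    "num_attention_heads": 32,
    "num_query_groups": 8,
    "num_moe_experts": 16,
    "moe_ffn_hidden_size": 1024,
    "num_layers": 20,
}

def forward_loop(model):
    # Run calibration batches so activation magnitudes can be collected for importance sorting.
    for batch in calib_dataloader:  # user-provided dataloader (assumed)
        model(**batch)

model, _ = mtp.prune(
    model,                                         # a Megatron-core GPT/Mamba/MoE model (assumed)
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,                              # not used by mcore_minitron (assumption)
    config={"forward_loop": forward_loop},
)
```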

modelopt/torch/nas/plugins/megatron.py
Lines changed: 70 additions & 40 deletions

@@ -24,6 +24,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.gpt import GPTModel
 from megatron.core.parallel_state import (
     get_data_parallel_group,
@@ -154,22 +155,54 @@ def _setup(self):
         )
 
 
+# Embedding DynamicModule ##########################################################################
 @DMRegistry.register(
-    {VocabParallelEmbedding: "megatron.core.tensor_parallel.layers.VocabParallelEmbedding"}
+    {
+        VocabParallelEmbedding: "megatron.core.tensor_parallel.layers.VocabParallelEmbedding",
+        nn.Embedding: "nn.Embedding",
+    },
 )
-class _DynamicVocabParallelEmbedding(DynamicModule):
-    """A VocabParallelEmbedding layer with dynamic hyperparams."""
+class _DynamicEmbedding(DynamicModule):
+    """A Embedding layer with dynamic hyperparams."""
 
     def _setup(self):
         self._register_hparam("embedding_dim", TracedHp(list(range(1, self.embedding_dim + 1))))
         self._register_dynamic_attribute("weight", self._get_weight)
 
     @staticmethod
-    def _get_weight(mod: "_DynamicVocabParallelEmbedding", weight: torch.Tensor) -> torch.Tensor:
+    def _get_weight(mod: "_DynamicEmbedding", weight: torch.Tensor) -> torch.Tensor:
         """Return the weight tensor of the embedding layer."""
         return get_sliced_tensor(mod, weight, None, "embedding_dim")
 
 
+@DMRegistry.register(
+    {
+        LanguageModelEmbedding: "megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding"
+    }
+)
+class _DynamicLanguageModelEmbedding(DynamicModule):
+    """A LanguageModelEmbedding layer with dynamic hyperparams."""
+
+    def _setup(self):
+        DMRegistry.convert(self.word_embeddings)
+        hp_hidden_size = self.word_embeddings.get_hparam("embedding_dim")
+        if hasattr(self, "position_embeddings") and self.position_embeddings is not None:
+            DMRegistry.convert(self.position_embeddings)
+            self.position_embeddings.embedding_dim = hp_hidden_size
+        if hasattr(self, "tokentype_embeddings") and self.tokentype_embeddings is not None:
+            DMRegistry.convert(self.tokentype_embeddings)
+            self.tokentype_embeddings.embedding_dim = hp_hidden_size
+
+    def export(self) -> torch.nn.Module:
+        self.word_embeddings.export()
+        if hasattr(self, "position_embeddings") and self.position_embeddings is not None:
+            self.position_embeddings.export()
+        if hasattr(self, "tokentype_embeddings") and self.tokentype_embeddings is not None:
+            self.tokentype_embeddings.export()
+        return super().export()
+
+
+# Normalization DynamicModule ######################################################################
 @DMRegistry.register({FusedLayerNorm: "megatron.core.fusions.fused_layer_norm.FusedLayerNorm"})
 class _DynamicFusedLayerNorm(_DynamicLayerNorm):
     """A FusedLayerNorm layer with dynamic hyperparams."""
@@ -211,8 +244,8 @@ def _setup(self):
             self.hparam_name = "moe_ffn_hidden_size"
         else:
             self.hparam_name = "ffn_hidden_size"
-        self.linear_fc1 = DMRegistry.convert(self.linear_fc1)
-        self.linear_fc2 = DMRegistry.convert(self.linear_fc2)
+        DMRegistry.convert(self.linear_fc1)
+        DMRegistry.convert(self.linear_fc2)
 
         ffn_hidden_size = TracedHp(list(range(1, self.config.ffn_hidden_size + 1)))
         fc1_output_size = (
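This hunk and most of the ones that follow are the same mechanical refactor: the diff relies on `DMRegistry.convert` mutating the module it is given, so re-assigning its return value (`self.linear_fc1 = DMRegistry.convert(self.linear_fc1)`) is redundant. One way to picture that contract is a toy stand-in that swaps the instance's class in place; this is only an analogy for the behavior the diff depends on, not modelopt's actual implementation:

```python
import torch.nn as nn

class _ToyDynamicLinear(nn.Linear):
    """Toy 'dynamic' subclass standing in for a converted module."""

def toy_convert(mod: nn.Linear) -> nn.Linear:
    # Convert in place by swapping the class, then return the same object.
    mod.__class__ = _ToyDynamicLinear
    return mod

layer = nn.Linear(8, 8)
toy_convert(layer)                           # no re-assignment needed
assert isinstance(layer, _ToyDynamicLinear)  # the original reference was converted
```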
@@ -537,7 +570,7 @@ def _setup(self):
             {"_setup": lambda self: None},
         )
 
-        self.core_attention = _DynamicDotProductAttention.convert(self.core_attention)
+        _DynamicDotProductAttention.convert(self.core_attention)
         self.core_attention._register_dynamic_attribute(
             "hidden_size_per_partition",
             lambda mod, val: self.config.kv_channels * self.num_attention_heads_per_partition,
@@ -559,7 +592,7 @@ def _setup(self):
             {"_setup": lambda self: None},
         )
 
-        self.core_attention = _DynamicTEDotProductAttention.convert(self.core_attention)
+        _DynamicTEDotProductAttention.convert(self.core_attention)
         self.core_attention._register_dynamic_attribute(
             "num_attention_heads", lambda mod, val: self.num_attention_heads_per_partition
         )
@@ -571,10 +604,10 @@ def _setup(self):
         )
 
         # Convert the fused qkv and output projection linear layer to dynamic module
-        self.linear_qkv = _DynamicQKVColumnParallelLinear.convert(
+        _DynamicQKVColumnParallelLinear.convert(
             self.linear_qkv, num_heads_per_group, num_query_groups
         )
-        self.linear_proj = _DynamicProjRowParallelLinear.convert(
+        _DynamicProjRowParallelLinear.convert(
             self.linear_proj, num_heads_per_group, num_query_groups
         )
 
@@ -670,6 +703,8 @@ def _setup(self):
 
         # Register dynamic attributes
         self._register_dynamic_attribute("weight", self._get_router_weight)
+        if self.config.add_bias_linear:
+            self._register_dynamic_attribute("bias", self._get_slice_by_num_experts)
         if self.enable_expert_bias:
             self._register_dynamic_attribute(
                 "local_tokens_per_expert", self._get_slice_by_num_experts
@@ -681,8 +716,8 @@ def _get_router_weight(mod: "_DynamicTopKRouter", weight: torch.Tensor) -> torch
         return get_sliced_tensor(mod, weight, "num_experts", "hidden_size")
 
     @staticmethod
-    def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", bias: torch.Tensor) -> torch.Tensor:
-        return get_sliced_tensor(mod, bias, "num_experts")
+    def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", val: torch.Tensor) -> torch.Tensor:
+        return get_sliced_tensor(mod, val, "num_experts")
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden_size hparam for router weights from global hidden_size hparam."""
@@ -699,10 +734,10 @@ def _setup(self):
         self._register_hparam("num_local_experts", num_moe_experts)
 
         # Convert local_experts list and each individual expert MLP to dynamic modules
-        self.local_experts = DynamicModuleList.convert(self.local_experts)
+        DynamicModuleList.convert(self.local_experts)
         self.local_experts.depth = num_moe_experts  # Reuse same hparam for depth
         for i in range(len(self.local_experts)):
-            self.local_experts[i] = DMRegistry.convert(self.local_experts[i])
+            DMRegistry.convert(self.local_experts[i])
 
         # Track forward activations for importance estimation.
         # _activations name is needed for get_activations_and_layer_scores to save scores for re-running pruning.
@@ -777,8 +812,8 @@ def _setup(self):
         # Convert to dynamic modules
         # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so
         # importance estimator and depth hparam is retained.
-        self.router = DMRegistry.convert(self.router)
-        self.experts = DMRegistry.convert(self.experts)
+        DMRegistry.convert(self.router)
+        DMRegistry.convert(self.experts)
         num_moe_experts_hp = self.experts.get_hparam("num_local_experts")
 
         # NOTE: Use num_moe_experts hparam name in top-level module to match TransformerConfig's name
@@ -789,7 +824,7 @@ def _setup(self):
         )
         self.router.num_experts = num_moe_experts_hp
         if self.use_shared_expert:
-            self.shared_experts = DMRegistry.convert(self.shared_experts)
+            DMRegistry.convert(self.shared_experts)
 
     def forward(self, *args, **kwargs):
         """Forward pass for the MoE layer."""
@@ -898,12 +933,12 @@ def _setup(self):
         # Convert the layernorms, self-attention, and mlp/moe layers to dynamic modules
         # NOTE: Mamba stack layers have either Attention or MLP, not both unlike GPT models
         if isinstance(self.self_attention, SelfAttention):
-            self.input_layernorm = DMRegistry.convert(self.input_layernorm)
-            self.self_attention = DMRegistry.convert(self.self_attention)
+            DMRegistry.convert(self.input_layernorm)
+            DMRegistry.convert(self.self_attention)
 
         if isinstance(self.mlp, (MLP, MoELayer)):
-            self.pre_mlp_layernorm = DMRegistry.convert(self.pre_mlp_layernorm)
-            self.mlp = DMRegistry.convert(self.mlp)
+            DMRegistry.convert(self.pre_mlp_layernorm)
+            DMRegistry.convert(self.mlp)
 
         # Register forward hook to collect activations for importance estimation
         self._setup_mixin()
@@ -1168,23 +1203,23 @@ def _setup(self):
         self._register_dynamic_attribute("headdim", lambda mod, val: self.mamba_head_dim)
 
         # Convert to dynamic modules
-        self.in_proj = DMRegistry.convert(self.in_proj)
+        DMRegistry.convert(self.in_proj)
         self.in_proj.output_size = build_concat_hp(
             [d_inner, d_inner, bc, mamba_num_heads]
         )  # z, x, B, C, dt
 
         conv_dim = build_concat_hp([d_inner, bc])  # z, B, C
-        self.conv1d = DMRegistry.convert(self.conv1d)
+        DMRegistry.convert(self.conv1d)
         self.conv1d.in_channels = conv_dim
         self.conv1d.out_channels = conv_dim
         ks = self.conv1d.get_hparam("kernel_size")
         ks.choices = [ks.original]
 
         if self.rmsnorm:
-            self.norm = DMRegistry.convert(self.norm)
+            DMRegistry.convert(self.norm)
             self.norm.hidden_size = d_inner
 
-        self.out_proj = DMRegistry.convert(self.out_proj)
+        DMRegistry.convert(self.out_proj)
         self.out_proj.input_size = d_inner
 
         # Register dynamic attributes for Mamba-specific parameters
@@ -1310,8 +1345,8 @@ class _DynamicMambaLayer(DynamicModule, MambaTransformerLayerMixin):
 
     def _setup(self):
         # Convert to dynamic module
-        self.mixer = DMRegistry.convert(self.mixer)
-        self.norm = DMRegistry.convert(self.norm)
+        DMRegistry.convert(self.mixer)
+        DMRegistry.convert(self.norm)
         self._setup_mixin()
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
@@ -1336,7 +1371,7 @@ def modify(
             ("mamba_head_dim", mamba_head_dim_divisor),
         ]:
             hp = self.mixer.get_hparam(hp_name)
-            choices = {int(make_divisible(c, divisor)) for c in hp.choices}  # type: ignore[arg-type]
+            choices = {int(make_divisible(c, divisor)) for c in hp.choices}
            hp.choices = list(set(hp.choices) & choices | {hp.original})
 
     def export(self):
@@ -1376,24 +1411,21 @@ def _setup(self):
         assert self.config.expert_model_parallel_size == 1, "Expert parallel is not supported."
         assert self.pre_process == is_pipeline_first_stage()
         assert self.post_process == is_pipeline_last_stage()
-        assert self.position_embedding_type in ["rope", "none"], (
-            f"Only rope position embedding is supported, got {self.position_embedding_type}."
-        )
 
         # Register num_layers hparam for depth pruning
         self._register_hparam("num_layers", TracedHp(list(range(1, self.config.num_layers + 1))))
 
         # Convert layers to dynamic modules and set the shared hidden_size hparam
         if is_pipeline_first_stage():
-            self.embedding.word_embeddings = DMRegistry.convert(self.embedding.word_embeddings)
+            DMRegistry.convert(self.embedding)
             hidden_size = self.embedding.word_embeddings.get_hparam("embedding_dim")
         else:
             hidden_size = None
         hidden_size = dist.broadcast(hidden_size, src=0)
         self._register_hparam("hidden_size", hidden_size)
 
         for i in range(len(self.decoder.layers)):
-            self.decoder.layers[i] = DMRegistry.convert(self.decoder.layers[i])
+            DMRegistry.convert(self.decoder.layers[i])
             self.decoder.layers[i].set_hidden_size_hp(hidden_size)
 
         # NOTE: GPTModel has final_layernorm, MambaModel has final_norm
@@ -1409,7 +1441,7 @@ def _setup(self):
             DMRegistry.convert(getattr(self.decoder, self.final_norm_attr_name)),
         )
         getattr(self.decoder, self.final_norm_attr_name).num_features = hidden_size
-        self.output_layer = DMRegistry.convert(self.output_layer)
+        DMRegistry.convert(self.output_layer)
         self.output_layer.input_size = hidden_size
         self.output_layer.get_hparam("output_size").choices = [self.output_layer.output_size]
 
@@ -1548,7 +1580,7 @@ def export(self) -> torch.nn.Module:
             handle.remove()
         self._export_drop_layers()
         if is_pipeline_first_stage():
-            self.embedding.word_embeddings.export()
+            self.embedding.export()
         for layer in self.decoder.layers:
             layer.export()
         if is_pipeline_last_stage():
@@ -1587,9 +1619,7 @@ def set_activations_and_layer_scores(
         rank = get_pipeline_model_parallel_rank()
         pp_size = get_pipeline_model_parallel_world_size()
         assert len(activations_per_rank) == pp_size, (
-            len(activations_per_rank),
-            activations_per_rank,
-            pp_size,
+            f"Expected same PP size for stored pruning scores ({len(activations_per_rank)}) as current ({pp_size})!"
         )
         for layer in self.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
@@ -1611,14 +1641,14 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i
     )
 
     supported_model_types = tuple(SUPPORTED_MODELS.keys())
-    for m in model.modules():
+    for n, m in model.named_modules():
         if isinstance(m, supported_model_types):
             model = m
             break
     assert isinstance(model, supported_model_types), (
         f"Model should have one of {supported_model_types} submodule, got {model}"
     )
-    print_rank_0(f"Dropping layers {layers_to_drop} from {type(model)}.")
+    print_rank_0(f"Dropping layers {layers_to_drop} from {n} ({type(model)}).")
 
     # get the number of layers remaining in each pp rank
     layers_remaining_per_pp = torch.zeros(
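For reference, the function touched in the last hunk is the entry point for depth pruning by explicit layer indices; a minimal call could look like the following (the layer indices are hypothetical, and `model` is assumed to be an already-built Megatron-core GPT or Mamba model):

```python
from modelopt.torch.nas.plugins.megatron import drop_mcore_language_model_layers

# layers_to_drop is keyword-only, matching the signature shown above.
drop_mcore_language_model_layers(model, layers_to_drop=[20, 21, 22])
```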

tests/_test_utils/torch/megatron/models.py
Lines changed: 14 additions & 1 deletion

@@ -138,6 +138,7 @@ def get_mcore_gpt_model(
     ffn_hidden_size: int | None = 128,
     max_sequence_length: int = 16,
     vocab_size: int = 64,
+    position_embedding_type: str = "rope",
     activation_func: str = "swiglu",
     normalization: str = "LayerNorm",
     transformer_impl: str = "modelopt" if HAS_TE else "local",
@@ -191,9 +192,21 @@ def squared_relu(x):
         moe_router_dtype="fp32",
         moe_ffn_hidden_size=moe_ffn_hidden_size,
         moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
+        moe_router_enable_expert_bias=True,
+        moe_router_score_function="sigmoid",
         num_moe_experts=num_moe_experts,
     )
 
+    if position_embedding_type == "yarn":  # gpt-oss like model
+        warn("Yarn RoPE config format will change soon. This is a temporary workaround")
+        config.yarn_rotary_scaling_factor = 32.0
+        config.yarn_original_max_position_embeddings = 4096
+        config.yarn_beta_fast = 32.0
+        config.yarn_beta_slow = 1.0
+        config.yarn_mscale = 1.0
+        config.yarn_mscale_all_dim = 0.0
+        config.yarn_correction_range_round_to_int = False
+
     if transformer_impl == "local":
         assert HAS_APEX, "Apex not installed"
         transformer_layer_spec = get_gpt_layer_local_spec(
@@ -224,7 +237,7 @@ def squared_relu(x):
         pre_process=is_pipeline_first_stage(),
         post_process=is_pipeline_last_stage(),
         share_embeddings_and_output_weights=False,
-        position_embedding_type="rope",
+        position_embedding_type=position_embedding_type,
     )
     return model.to(torch.bfloat16) if bf16 else model
 
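With the new `position_embedding_type` argument, a test can build a tiny gpt-oss-like model with Yarn RoPE directly from this helper. A rough usage sketch (the import path assumes the `tests` directory layout shown above, and that the helper's remaining arguments have workable defaults):

```python
from _test_utils.torch.megatron.models import get_mcore_gpt_model  # assumed import path

# Triggers the temporary Yarn config block added above (scaling factor 32, etc.).
model = get_mcore_gpt_model(
    vocab_size=64,
    max_sequence_length=16,
    position_embedding_type="yarn",
)
```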
