
Conversation

@kevalmorabia97
Collaborator

@kevalmorabia97 kevalmorabia97 commented Jan 21, 2026

What does this PR do?

Type of change: new example

Adds Megatron-Bridge pruning example scripts (HF input, HF or Megatron output). Also defines some utility functions we can reuse when adding examples for quantization or other optimizations (usage sketch below):

  • modelopt.torch.utils.plugins.mbridge.load_mbridge_model_from_hf: Loads an HF checkpoint into a Megatron-Bridge model with the ModelOpt spec in the desired TP/PP/etc. configuration
  • modelopt.torch.utils.plugins.mbridge.get_hf_mbridge_calibration_loop: Creates a forward_loop for calibration on an HF dataset
    • Supports all datasets available in modelopt.torch.utils.dataset_utils (cnn_dailymail, nemotron-post-training-dataset-v2, etc.)
    • Supports micro batch sizes >= 1
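
A minimal sketch of wiring the two utilities together; the keyword arguments shown are illustrative assumptions rather than the exact signatures, and the returned (model, provider, unwrapped_model) tuple follows the walkthrough below:

    # Hypothetical usage; argument names are assumptions, not the final API.
    import modelopt.torch.prune as mtp
    from modelopt.torch.utils.plugins.mbridge import (
        get_hf_mbridge_calibration_loop,
        load_mbridge_model_from_hf,
    )

    # Load an HF checkpoint into a Megatron-Bridge model with the ModelOpt spec
    model, provider, unwrapped_model = load_mbridge_model_from_hf("Qwen/Qwen3-8B")

    # Build a calibration forward_loop over an HF dataset
    forward_loop = get_hf_mbridge_calibration_loop(
        model, dataset_name="cnn_dailymail", num_samples=512, micro_batch_size=1
    )

    # The loop can then be passed to ModelOpt APIs, e.g. mtp.prune with forward_loop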

Usage

From the nvcr.io/nvidian/nemo:26.02.rc1 container (with the latest code mounted to /opt/Megatron-Bridge and /opt/Model-Optimizer):

torchrun --nproc_per_node 2 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
    --hf_model_name_or_path Qwen/Qwen3-8B \
    --prune_target_params 6e9 \
    --hparams_to_skip num_attention_heads \
    --output_hf_path /tmp/Qwen3-8B-Pruned-6B

Testing

  • Manually ran the pruning script in the nemo:25.11 container (with modelopt and mbridge mounted at latest) for Qwen3-8B and Nemotron-Nano-9B-v2 with PP=8 and PP=4
  • Added per-PR CI/CD test for example script

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: ‼️ TODO
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

Release Notes

New Features

  • Added a new Megatron-Bridge pruning example demonstrating Minitron-based model optimization with advanced pruning configurations.

Documentation

  • Updated core project documentation to highlight Megatron-Bridge as a supported optimization framework.
  • Added comprehensive example documentation for Megatron-Bridge workflows including pruning, distillation, and quantization.
  • Updated pruning guides with Megatron-Bridge integration examples and best practices.


@copy-pr-bot

copy-pr-bot bot commented Jan 21, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@coderabbitai
Contributor

coderabbitai bot commented Jan 21, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR introduces Megatron-Bridge support to ModelOpt with a new pruning example using Minitron on Qwen3-8B. Changes include new utility functions for loading HF models via Megatron-Bridge and constructing calibration loops, documentation updates, and a refactor of the plugins system to prevent auto-registration of the bridge plugin.

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| Configuration & Metadata | .github/CODEOWNERS, CHANGELOG.rst, README.md | Added a CODEOWNERS entry for megatron_bridge examples, updated the CHANGELOG with the new pruning example, and replaced NVIDIA NeMo with NVIDIA Megatron-Bridge in integration target documentation. |
| New Megatron-Bridge Example | examples/megatron_bridge/README.md, examples/megatron_bridge/prune_minitron.py | Introduced a new example directory with pruning documentation and a CLI script orchestrating NAS-based pruning of Megatron-Bridge models using the Minitron algorithm. Supports Qwen3-8B to 6B reduction with calibration dataset configuration, pruning modes (target params or export config), and dual-format model saving (Megatron or HF checkpoint). |
| Documentation Updates | examples/pruning/README.md | Updated terminology from Megatron-LM/NeMo to Megatron-LM (M-LM) and Megatron-Bridge (M-Bridge), replaced the NeMo container reference (25.11 → 26.02), rewrote code examples to use Megatron-Bridge model loading utilities, and adjusted model applicability notes for pipeline parallelism. |
| Core Megatron-Bridge Utilities | modelopt/torch/utils/plugins/mbridge.py | New module providing load_mbridge_model_from_hf() for instantiating Megatron-Bridge models from HF checkpoints with provider customization, get_hf_mbridge_calibration_loop() for constructing ModelOpt calibration loops, and an internal _get_dataset_cfg() helper for dataset preparation. |
| Utility Modifications | modelopt/torch/utils/dataset_utils.py | Exposed get_dataset_samples() as a public function (previously private _get_dataset_samples) and added it to the __all__ export list. |
| Distributed Utility Updates | modelopt/torch/utils/distributed.py | Modified local_rank() to fall back to the global rank with a warning instead of raising an error when the LOCAL_RANK environment variable is missing. |
| Plugin System Refactoring | modelopt/torch/utils/plugins/__init__.py | Commented out the auto-import of the megatron bridge plugin at module initialization while retaining the other plugin imports (megatron_generate, megatron_mmlu, megatron_preprocess_data). |

Sequence Diagram(s)

sequenceDiagram
    participant CLI as prune_minitron.py
    participant Bridge as Megatron-Bridge
    participant Dataset as Dataset/Calibration
    participant Pruning as ModelOpt Pruning
    participant Output as Output Format

    CLI->>CLI: Parse arguments & validate config
    CLI->>Bridge: load_mbridge_model_from_hf()
    Bridge-->>CLI: Return model, provider, unwrapped_model
    CLI->>Dataset: get_hf_mbridge_calibration_loop()
    Dataset-->>CLI: Return calibration loop closure
    CLI->>Pruning: Build NAS search space & config
    CLI->>Pruning: Execute pruning with forward_loop
    Pruning-->>CLI: Return pruned model
    CLI->>Output: Save to Megatron or HF format
    Output-->>CLI: Model checkpoint written

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 58.33%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add Megatron-Bridge pruning example scripts' accurately captures the main change: new example scripts for Megatron-Bridge pruning with supporting utility functions, directly reflecting the PR's primary objective. |




@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from ae6a842 to dc1cadc Compare January 21, 2026 09:58
@kevalmorabia97 kevalmorabia97 changed the base branch from main to kmorabia/minitron-auto January 21, 2026 09:59
@codecov

codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 13.33333% with 13 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.13%. Comparing base (945ee02) to head (0217833).
⚠️ Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| modelopt/torch/utils/dataset_utils.py | 8.33% | 11 Missing ⚠️ |
| modelopt/torch/utils/distributed.py | 33.33% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #800      +/-   ##
==========================================
- Coverage   74.18%   74.13%   -0.06%     
==========================================
  Files         192      192              
  Lines       19236    19258      +22     
==========================================
+ Hits        14271    14277       +6     
- Misses       4965     4981      +16     

☔ View full report in Codecov by Sentry.

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch 2 times, most recently from 2281a23 to 9c79afd Compare January 21, 2026 12:37
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from 9c79afd to a65050a Compare January 21, 2026 20:36
Base automatically changed from kmorabia/minitron-auto to main January 21, 2026 22:34
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from a65050a to 44920ad Compare January 22, 2026 10:40
@kevalmorabia97 kevalmorabia97 marked this pull request as ready for review January 22, 2026 10:43
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners January 22, 2026 10:43
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@examples/megatron_bridge/prune_minitron.py`:
- Around lines 205-210: Guard the output-file existence checks by whether the corresponding output path args are set: only call os.path.exists(args.output_megatron_path) or os.path.exists(args.output_hf_path) when the respective argument is truthy (not None or empty), keeping the existing warn_rank_0(...) and early-return behavior unchanged. See the first sketch after this list.
- Around lines 167-175: The defaulting logic for args.prune_intermediate_checkpoint can point to a file under args.output_megatron_path or args.output_hf_path whose parent directory may not exist. Right after the block that sets args.prune_intermediate_checkpoint, and before any prune or save calls, create the parent directory with os.makedirs(os.path.dirname(...), exist_ok=True) so mtp.prune won't fail on a missing directory. See the first sketch after this list.

In `@examples/megatron_bridge/README.md`:
- Around lines 59-62: Remove the duplicate "## Resources" heading so that only a single "## Resources" section header remains in the README.

In `@modelopt/torch/utils/plugins/mbridge.py`:
- Around lines 162-163: The code sets global_batch_size = micro_batch_size and computes num_iters = num_samples // global_batch_size, which yields zero iterations when num_samples < micro_batch_size. Guard against invalid input (raise or return early if num_samples <= 0) and use ceiling division, num_iters = math.ceil(num_samples / global_batch_size), so at least one calibration iteration runs whenever num_samples > 0. Update any callers that depend on num_iters and add a short unit test or assertion near these variables. See the second sketch after this list.
- Around lines 118-135: _get_dataset_cfg passes the raw list[str] returned by get_dataset_samples into DatasetDict, a type mismatch since DatasetDict expects datasets.Dataset instances. Convert the list to a HuggingFace Dataset first (e.g. Dataset.from_dict({"text": dataset})) and update process_example_fn to reference the same field (example["text"]), keeping HFDatasetConfig, DatasetDict, get_dataset_samples, and process_example_fn in the fix. See the third sketch after this list.
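
A minimal sketch of the two prune_minitron.py fixes above (the existence-check guard and the checkpoint directory creation), assuming the argument names quoted in the comments; warn_rank_0 and the surrounding function come from the script, and the warning text is illustrative:

    import os

    # Only test existence when the corresponding output path is set
    if args.output_megatron_path and os.path.exists(args.output_megatron_path):
        warn_rank_0(f"{args.output_megatron_path} already exists; skipping.")
        return
    if args.output_hf_path and os.path.exists(args.output_hf_path):
        warn_rank_0(f"{args.output_hf_path} already exists; skipping.")
        return

    # Ensure the intermediate checkpoint's parent directory exists before
    # mtp.prune tries to write to it
    if args.prune_intermediate_checkpoint:
        os.makedirs(os.path.dirname(args.prune_intermediate_checkpoint), exist_ok=True)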
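A sketch of the ceiling-division fix for mbridge.py, assuming the variable names quoted above:

    import math

    global_batch_size = micro_batch_size
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    # Ceiling division so at least one calibration iteration runs even when
    # num_samples < global_batch_size
    num_iters = math.ceil(num_samples / global_batch_size)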
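And a sketch of the DatasetDict fix, assuming get_dataset_samples returns list[str] as described; the "text" field and "train" split names are illustrative:

    from datasets import Dataset, DatasetDict

    samples: list[str] = get_dataset_samples(dataset_name, num_samples)

    # Wrap the raw list[str] in a datasets.Dataset before building the DatasetDict
    dataset_dict = DatasetDict({"train": Dataset.from_dict({"text": samples})})

    # process_example_fn must then read the same field
    def process_example_fn(example):
        return example["text"]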
🧹 Nitpick comments (4)
modelopt/torch/utils/distributed.py (1)

76-81: Consider using warnings.warn with a filter to avoid repeated warnings.

The warning will fire on every call to local_rank() when LOCAL_RANK is not set. For workflows that call this function repeatedly, this could produce excessive log noise.

♻️ Suggested fix using `warnings.warn` with stacklevel and category
+import functools
+
+@functools.lru_cache(maxsize=1)
+def _warn_local_rank_fallback():
+    warn("LOCAL_RANK environment variable not found. Using global rank instead.", stacklevel=3)
+
 def local_rank() -> int:
     """Returns the local rank of the current process."""
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
-    warn("LOCAL_RANK environment variable not found. Using global rank instead.")
+    _warn_local_rank_fallback()
     return rank()
modelopt/torch/utils/plugins/__init__.py (1)

28-32: LGTM with minor grammar nit.

The comment explaining why the Megatron-Bridge plugin is not pre-imported is helpful for maintainability.

✏️ Optional: Fix apostrophes in comment
-# NOTE: Dont pre-import megatron bridge plugin here to avoid circular dependency issues.
-#   We dont register anything so this isnt a problem.
+# NOTE: Don't pre-import megatron bridge plugin here to avoid circular dependency issues.
+#   We don't register anything so this isn't a problem.
modelopt/torch/utils/plugins/mbridge.py (2)

97-109: Consider handling NemotronHModelProvider explicitly and improving error messages.

  1. NemotronHModelProvider is imported (line 32) and used in get_hf_mbridge_calibration_loop (line 166), but it falls through to the else branch here. If this is intentional, consider documenting it or adding it to the type hints.

  2. The assert statements provide no context when they fail.

♻️ Suggested improvements
     if isinstance(provider, MambaModelProvider):
         provider.mamba_stack_spec = modelopt_mamba_stack_spec
     else:
+        # GPTModelProvider and NemotronHModelProvider use transformer_layer_spec
         provider.transformer_layer_spec = modelopt_transformer_layer_spec
 
     provider.finalize()
     if init_model_parallel:
         provider.initialize_model_parallel(seed=0)
 
     model = provider.provide_distributed_model(wrap_with_ddp=False)
-    assert len(model) == 1
+    assert len(model) == 1, f"Expected single model, got {len(model)} models"
     unwrapped_model = unwrap_model(model[0])
-    assert isinstance(unwrapped_model, (GPTModel, MambaModel))
+    assert isinstance(unwrapped_model, (GPTModel, MambaModel)), (
+        f"Expected GPTModel or MambaModel, got {type(unwrapped_model)}"
+    )

Also consider updating the return type annotation on line 65 to include NemotronHModelProvider if it's a supported provider type:

GPTModelProvider | MambaModelProvider | NemotronHModelProvider

211-221: Unused parameter m in forward_loop - consider prefixing with underscore.

The forward_loop function accepts parameter m but uses the outer scope model variable instead. This is likely intentional for ModelOpt API compatibility, but the unused parameter could confuse readers.

♻️ Suggested fix: prefix unused parameter
-    def forward_loop(m):
+    def forward_loop(_model):
+        # NOTE: _model parameter is unused; the Megatron model list from closure is used instead
         evaluate_and_print_results(
             state,
             prefix="iteration 1",
             forward_step_func=forward_step,
             data_iterator=train_data_iterator,
             model=model,
             config=cfg,
             verbose=True,
             write_to_tensorboard=False,
         )

@lars-reimann

This works really nicely overall, thanks. I've got a question about the use of top_k, though: If I run

torchrun --nproc_per_node 4 /opt/TensorRT-Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
    --hf_model_name_or_path Qwen/Qwen3-8B \
    --prune_target_params 6e9 \
    --hparams_to_skip num_attention_heads \
    --output_hf_path /workspace/checkpoints/Qwen3-8B-Pruned-6B \
    --calib_dataset_name magpie \
    --top_k 1

I get the output

Using top 1 candidates from checkpoint

====================
Top 1 candidates:
	{'num_layers': 34, 'hidden_size': 3328, 'ffn_hidden_size': 11264} -> 5.99B params
	{'num_layers': 30, 'hidden_size': 3584, 'ffn_hidden_size': 11776} -> 5.99B params
	{'num_layers': 36, 'hidden_size': 3840, 'ffn_hidden_size': 8192} -> 5.98B params
	{'num_layers': 36, 'hidden_size': 3584, 'ffn_hidden_size': 9216} -> 5.98B params
	{'num_layers': 36, 'hidden_size': 3072, 'ffn_hidden_size': 11776} -> 5.97B params
	{'num_layers': 32, 'hidden_size': 3584, 'ffn_hidden_size': 10752} -> 5.96B params
	{'num_layers': 36, 'hidden_size': 3328, 'ffn_hidden_size': 10240} -> 5.92B params
	{'num_layers': 34, 'hidden_size': 3840, 'ffn_hidden_size': 8704} -> 5.91B params
	{'num_layers': 30, 'hidden_size': 4096, 'ffn_hidden_size': 9216} -> 5.90B params
	{'num_layers': 34, 'hidden_size': 3584, 'ffn_hidden_size': 9728} -> 5.89B params
====================

MMLU Pro is then run for the 10 listed configurations instead of just 1. Is this intended?

@kevalmorabia97
Collaborator Author

kevalmorabia97 commented Jan 22, 2026

@lars-reimann thanks for being an early user of this script and providing feedback!
May I know if you re-ran the pruning script with a different top-k? Initially with 10 and then with 1? I think your cached pruning-scores checkpoint might still have been there, so the top-k candidates were read directly from the checkpoint instead of being re-computed.

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from 10a12c9 to 678f5a2 Compare January 22, 2026 19:51
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from 678f5a2 to 0217833 Compare January 22, 2026 19:54
@lars-reimann

> @lars-reimann thanks for being an early user of this script and providing feedback! May I know if you re-ran the pruning script with a different top-k? Initially with 10 and then with 1? I think your cached pruning-scores checkpoint might still have been there, so the top-k candidates were read directly from the checkpoint instead of being re-computed.

That was the issue, thanks. I first executed the script with the default top_k = 10, but aborted that run. Then I ran it again with top_k = 1 without changing output_hf_path.

When using an empty output directory, the top_k argument works as expected.

