Add Megatron-Bridge pruning example scripts #800
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Review skipped: auto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI.

📝 Walkthrough

This PR introduces Megatron-Bridge support to ModelOpt with a new pruning example using Minitron on Qwen3-8B. Changes include new utility functions for loading HF models via Megatron-Bridge, constructing calibration loops, updating documentation, and refactoring the plugins system to prevent auto-registration of the bridge plugin.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as prune_minitron.py
    participant Bridge as Megatron-Bridge
    participant Dataset as Dataset/Calibration
    participant Pruning as ModelOpt Pruning
    participant Output as Output Format

    CLI->>CLI: Parse arguments & validate config
    CLI->>Bridge: load_mbridge_model_from_hf()
    Bridge-->>CLI: Return model, provider, unwrapped_model
    CLI->>Dataset: get_hf_mbridge_calibration_loop()
    Dataset-->>CLI: Return calibration loop closure
    CLI->>Pruning: Build NAS search space & config
    CLI->>Pruning: Execute pruning with forward_loop
    Pruning-->>CLI: Return pruned model
    CLI->>Output: Save to Megatron or HF format
    Output-->>CLI: Model checkpoint written
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Force-pushed from ae6a842 to dc1cadc
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #800      +/-   ##
==========================================
- Coverage   74.18%   74.13%   -0.06%
==========================================
  Files         192      192
  Lines       19236    19258      +22
==========================================
+ Hits        14271    14277       +6
- Misses       4965     4981      +16
```

☔ View full report in Codecov by Sentry.
Force-pushed from 2281a23 to 9c79afd
Force-pushed from 9c79afd to a65050a
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Force-pushed from a65050a to 44920ad
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@examples/megatron_bridge/prune_minitron.py`:
- Around line 205-210: The existence checks for output files should be guarded
by whether the corresponding output path args are set: before calling
os.path.exists with args.output_megatron_path or args.output_hf_path, check that
args.output_megatron_path and args.output_hf_path are truthy respectively;
update the block around the os.path.exists checks (the conditions using
args.output_megatron_path and args.output_hf_path) so you only call
os.path.exists when the arg is not None/empty, and keep the existing
warn_rank_0(...) and return behavior unchanged.
- Around line 167-175: The defaulting logic for
args.prune_intermediate_checkpoint can point to a file under
args.output_megatron_path or args.output_hf_path which may not exist; before
calling mtp.prune (or any operation that writes that checkpoint) create the
parent directory for args.prune_intermediate_checkpoint using os.path.dirname
and os.makedirs(..., exist_ok=True). Ensure the directory creation happens right
after the block that sets args.prune_intermediate_checkpoint and before any
prune/save calls so mtp.prune won't fail due to a missing directory.
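For the first item above, a minimal sketch of the guarded existence check, assuming the `args.output_megatron_path` / `args.output_hf_path` arguments and the `warn_rank_0` helper referenced there; the wrapper function, message text, and return value are hypothetical, not the exact `prune_minitron.py` code.

```python
import os
from warnings import warn  # the script uses its own warn_rank_0 helper


def outputs_already_exist(output_megatron_path: str | None, output_hf_path: str | None) -> bool:
    """Return True (and warn) if a provided output path already exists on disk."""
    for path in (output_megatron_path, output_hf_path):
        # Only probe paths that were actually set; os.path.exists(None) raises TypeError.
        if path and os.path.exists(path):
            warn(f"{path} already exists, skipping pruning.")
            return True
    return False
```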
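And for the second item, a sketch of creating the intermediate checkpoint's parent directory before `mtp.prune` writes to it; the function name is hypothetical, and the path would come from `args.prune_intermediate_checkpoint` after its default is resolved.

```python
import os


def ensure_parent_dir(prune_intermediate_checkpoint: str) -> None:
    """Create the parent directory of the intermediate checkpoint path if it is missing."""
    parent = os.path.dirname(prune_intermediate_checkpoint)
    if parent:  # dirname is "" for a bare filename in the current directory
        os.makedirs(parent, exist_ok=True)
```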
In `@examples/megatron_bridge/README.md`:
- Around line 59-62: Remove the duplicate "## Resources" heading: locate the
repeated heading string "## Resources" in the README (the two identical headings
shown in the diff) and delete the redundant one so there is only a single "##
Resources" section header remaining.
In `@modelopt/torch/utils/plugins/mbridge.py`:
- Around line 162-163: The code sets global_batch_size = micro_batch_size and
does integer division num_iters = num_samples // global_batch_size which yields
zero when num_samples < micro_batch_size; fix by using ceiling division and
guard zero samples: if num_samples <= 0 raise/return early for invalid input,
otherwise compute num_iters = max(1, (num_samples + global_batch_size - 1) //
global_batch_size) (or math.ceil(num_samples / global_batch_size)) so at least
one calibration iteration runs when num_samples > 0; update any callers
expecting num_iters and add a short unit test or assertion near these variables
(global_batch_size, micro_batch_size, num_samples, num_iters).
- Around line 118-135: _get_dataset_cfg currently passes the raw list returned
by get_dataset_samples into DatasetDict, causing a type mismatch because
DatasetDict expects datasets.Dataset instances; convert the list[str] to a
HuggingFace Dataset (e.g., datasets.Dataset.from_dict or from_list) before
constructing DatasetDict and update the process_example_fn to reference the
field name used (for example use Dataset.from_dict({"text": dataset}) and change
process_example_fn to use example["text"]), keeping references to
HFDatasetConfig, DatasetDict, get_dataset_samples, and process_example_fn in the
fix.
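For the calibration-iteration fix above, a sketch of the ceiling-division logic (with `global_batch_size = micro_batch_size` as in `mbridge.py`); the function name here is hypothetical.

```python
import math


def num_calibration_iters(num_samples: int, global_batch_size: int) -> int:
    """Ceiling division so that e.g. 100 samples with batch size 512 still yields 1 iteration."""
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    # Equivalent to (num_samples + global_batch_size - 1) // global_batch_size
    return math.ceil(num_samples / global_batch_size)
```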
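For the `DatasetDict` type mismatch above, a sketch of wrapping the raw `list[str]` in a HuggingFace `Dataset` before building the `DatasetDict`; the split name `"train"`, the `"text"` field, and the sample strings are assumptions, and the actual `HFDatasetConfig` wiring in `_get_dataset_cfg` is not reproduced here.

```python
from datasets import Dataset, DatasetDict

# Stand-in for the list[str] returned by get_dataset_samples(...)
samples = ["first calibration sample", "second calibration sample"]

# Wrap the raw strings so DatasetDict receives datasets.Dataset instances.
dataset_dict = DatasetDict({"train": Dataset.from_dict({"text": samples})})


def process_example_fn(example: dict) -> str:
    # Field name must match the key used in Dataset.from_dict above.
    return example["text"]
```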
🧹 Nitpick comments (4)
modelopt/torch/utils/distributed.py (1)
76-81: Consider using `warnings.warn` with a filter to avoid repeated warnings.

The warning will fire on every call to `local_rank()` when `LOCAL_RANK` is not set. For workflows that call this function repeatedly, this could produce excessive log noise.

♻️ Suggested fix using `warnings.warn` with stacklevel and category

```diff
+import functools
+
+@functools.lru_cache(maxsize=1)
+def _warn_local_rank_fallback():
+    warn("LOCAL_RANK environment variable not found. Using global rank instead.", stacklevel=3)
+
 def local_rank() -> int:
     """Returns the local rank of the current process."""
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
-    warn("LOCAL_RANK environment variable not found. Using global rank instead.")
+    _warn_local_rank_fallback()
     return rank()
```

modelopt/torch/utils/plugins/__init__.py (1)
28-32: LGTM with minor grammar nit.

The comment explaining why the Megatron-Bridge plugin is not pre-imported is helpful for maintainability.

✏️ Optional: Fix apostrophes in comment

```diff
-# NOTE: Dont pre-import megatron bridge plugin here to avoid circular dependency issues.
-# We dont register anything so this isnt a problem.
+# NOTE: Don't pre-import megatron bridge plugin here to avoid circular dependency issues.
+# We don't register anything so this isn't a problem.
```

modelopt/torch/utils/plugins/mbridge.py (2)
97-109: Consider handling `NemotronHModelProvider` explicitly and improving error messages.

`NemotronHModelProvider` is imported (line 32) and used in `get_hf_mbridge_calibration_loop` (line 166), but it falls through to the `else` branch here. If this is intentional, consider documenting it or adding it to the type hints. The `assert` statements provide no context when they fail.

♻️ Suggested improvements

```diff
 if isinstance(provider, MambaModelProvider):
     provider.mamba_stack_spec = modelopt_mamba_stack_spec
 else:
+    # GPTModelProvider and NemotronHModelProvider use transformer_layer_spec
     provider.transformer_layer_spec = modelopt_transformer_layer_spec
 provider.finalize()

 if init_model_parallel:
     provider.initialize_model_parallel(seed=0)
 model = provider.provide_distributed_model(wrap_with_ddp=False)
-assert len(model) == 1
+assert len(model) == 1, f"Expected single model, got {len(model)} models"
 unwrapped_model = unwrap_model(model[0])
-assert isinstance(unwrapped_model, (GPTModel, MambaModel))
+assert isinstance(unwrapped_model, (GPTModel, MambaModel)), (
+    f"Expected GPTModel or MambaModel, got {type(unwrapped_model)}"
+)
```

Also consider updating the return type annotation on line 65 to include `NemotronHModelProvider` if it's a supported provider type: `GPTModelProvider | MambaModelProvider | NemotronHModelProvider`
211-221: Unused parameter `m` in `forward_loop` - consider prefixing with underscore.

The `forward_loop` function accepts parameter `m` but uses the outer scope `model` variable instead. This is likely intentional for ModelOpt API compatibility, but the unused parameter could confuse readers.

♻️ Suggested fix: prefix unused parameter

```diff
-    def forward_loop(m):
+    def forward_loop(_model):
+        # NOTE: _model parameter is unused; the Megatron model list from closure is used instead
         evaluate_and_print_results(
             state,
             prefix="iteration 1",
             forward_step_func=forward_step,
             data_iterator=train_data_iterator,
             model=model,
             config=cfg,
             verbose=True,
             write_to_tensorboard=False,
         )
```
This works really nicely overall, thanks. I've got a question about the use of `--top_k`:

```sh
torchrun --nproc_per_node 4 /opt/TensorRT-Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
    --hf_model_name_or_path Qwen/Qwen3-8B \
    --prune_target_params 6e9 \
    --hparams_to_skip num_attention_heads \
    --output_hf_path /workspace/checkpoints/Qwen3-8B-Pruned-6B \
    --calib_dataset_name magpie \
    --top_k 1
```

I get the output

```
Using top 1 candidates from checkpoint
====================
Top 1 candidates:
{'num_layers': 34, 'hidden_size': 3328, 'ffn_hidden_size': 11264} -> 5.99B params
{'num_layers': 30, 'hidden_size': 3584, 'ffn_hidden_size': 11776} -> 5.99B params
{'num_layers': 36, 'hidden_size': 3840, 'ffn_hidden_size': 8192} -> 5.98B params
{'num_layers': 36, 'hidden_size': 3584, 'ffn_hidden_size': 9216} -> 5.98B params
{'num_layers': 36, 'hidden_size': 3072, 'ffn_hidden_size': 11776} -> 5.97B params
{'num_layers': 32, 'hidden_size': 3584, 'ffn_hidden_size': 10752} -> 5.96B params
{'num_layers': 36, 'hidden_size': 3328, 'ffn_hidden_size': 10240} -> 5.92B params
{'num_layers': 34, 'hidden_size': 3840, 'ffn_hidden_size': 8704} -> 5.91B params
{'num_layers': 30, 'hidden_size': 4096, 'ffn_hidden_size': 9216} -> 5.90B params
{'num_layers': 34, 'hidden_size': 3584, 'ffn_hidden_size': 9728} -> 5.89B params
====================
```

MMLU Pro is then run for the 10 listed configurations instead of just 1. Is this intended?
@lars-reimann thanks for being an early user of this script and providing feedback!
Force-pushed from 10a12c9 to 678f5a2
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Force-pushed from 678f5a2 to 0217833
That was the issue, thanks. I first executed the script with the default When using an empty output directory, the
What does this PR do?
Type of change: new example
Megatron-Bridge pruning example scripts (HF input, HF / Megatron output). Also defined some utility functions we can reuse for adding examples for quantization or other optimizations (a rough sketch of how they fit together is shown below):

- `modelopt.torch.utils.plugins.mbridge.load_mbridge_model_from_hf`: Load HF to MBridge with ModelOpt spec in desired TP/PP/etc. configuration
- `modelopt.torch.utils.plugins.mbridge.get_hf_mbridge_calibration_loop`: Create `forward_loop` for calibration on a HF dataset
- `modelopt.torch.utils.dataset_utils` (`cnn_dailymail`, `nemotron-post-training-dataset-v2`, etc.)
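A rough sketch of how these pieces could be wired together. The helper signatures, the `mcore_minitron` mode name, and the constraint key are assumptions inferred from this PR description and existing ModelOpt Minitron examples, not the exact `prune_minitron.py` code.

```python
import modelopt.torch.prune as mtp
from modelopt.torch.utils.plugins.mbridge import (
    get_hf_mbridge_calibration_loop,
    load_mbridge_model_from_hf,
)

# Load the HF checkpoint into a Megatron-Bridge model with the ModelOpt layer spec
# (per the PR it returns the model list, the provider, and the unwrapped Megatron model).
model, provider, unwrapped_model = load_mbridge_model_from_hf("Qwen/Qwen3-8B")

# Build a forward_loop that streams calibration samples through the model
# (the real helper likely also takes dataset / sample-count options).
forward_loop = get_hf_mbridge_calibration_loop(model, provider)

# Minitron pruning modifies the Megatron model in place; the constraint key below
# is illustrative (the example script exposes --prune_target_params 6e9 for this).
mtp.prune(
    unwrapped_model,
    mode="mcore_minitron",
    constraints={"params": 6e9},
    dummy_input=None,
    config={"forward_loop": forward_loop},
)
```

The pruned model can then be exported to Megatron or HF format, as in the Usage section below.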
Usage

From `nvcr.io/nvidian/nemo:26.02.rc1` container (mount latest code to `/opt/Megatron-Bridge` and `/opt/Model-Optimizer`)

Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation