From 44920ada802b2d08811d51e12e7aa8977d63df11 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 21 Jan 2026 01:45:47 -0800 Subject: [PATCH 1/2] Add Megatron-Bridge pruning example scripts Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/CODEOWNERS | 1 + CHANGELOG.rst | 1 + README.md | 2 +- examples/megatron_bridge/README.md | 67 ++++ examples/megatron_bridge/prune_minitron.py | 372 +++++++++++++++++++++ examples/pruning/README.md | 76 +++-- modelopt/torch/utils/dataset_utils.py | 5 +- modelopt/torch/utils/distributed.py | 4 +- modelopt/torch/utils/plugins/__init__.py | 5 + modelopt/torch/utils/plugins/mbridge.py | 223 ++++++++++++ 10 files changed, 717 insertions(+), 39 deletions(-) create mode 100644 examples/megatron_bridge/README.md create mode 100644 examples/megatron_bridge/prune_minitron.py create mode 100644 modelopt/torch/utils/plugins/mbridge.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a6b420b92..f97f2cbf5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -44,6 +44,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners /examples/llm_ptq @NVIDIA/modelopt-examples-llm_ptq-codeowners /examples/llm_qat @NVIDIA/modelopt-examples-llm_qat-codeowners /examples/llm_sparsity @NVIDIA/modelopt-torch-sparsity-codeowners +/examples/megatron_bridge @NVIDIA/modelopt-examples-megatron-codeowners /examples/model_hub @NVIDIA/modelopt-examples-model_hub-codeowners /examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners /examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 452f36538..bcc75a3d7 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint. - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. +- New example for Minitron pruning with Megatron-Bridge framework along with advanced pruning usage with new ``params`` constraint based pruning. Check `examples/megatron_bridge/README.md `_ for example scripts. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/README.md b/README.md index 14e91fd8b..b3282067e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ ______________________________________________________________________ **[Input]** Model Optimizer currently supports inputs of a [Hugging Face](https://huggingface.co/), [PyTorch](https://github.com/pytorch/pytorch) or [ONNX](https://github.com/onnx/onnx) model. **[Optimize]** Model Optimizer provides Python APIs for users to easily compose the above model optimization techniques and export an optimized quantized checkpoint. -Model Optimizer is also integrated with [NVIDIA NeMo](https://github.com/NVIDIA-NeMo/NeMo), [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and [Hugging Face Accelerate](https://github.com/huggingface/accelerate) for training required inference optimization techniques. 
+Model Optimizer is also integrated with [NVIDIA Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge), [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and [Hugging Face Accelerate](https://github.com/huggingface/accelerate) for training required inference optimization techniques. **[Export for deployment]** Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like [SGLang](https://github.com/sgl-project/sglang), [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization), [TensorRT](https://github.com/NVIDIA/TensorRT), or [vLLM](https://github.com/vllm-project/vllm). diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md new file mode 100644 index 000000000..fcae037c6 --- /dev/null +++ b/examples/megatron_bridge/README.md @@ -0,0 +1,67 @@ +# Megatron Bridge + +This directory contains examples of using Model Optimizer with [NeMo Megatron-Bridge](https://github.com/NVIDIA-Nemo/Megatron-Bridge) framework for pruning, distillation, quantization, etc. + +
+
+| **Section** | **Description** | **Link** | **Docs** |
+| :------------: | :------------: | :------------: | :------------: |
+| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] | |
+| Pruning | Examples of pruning a model using the Minitron algorithm | \[[Link](#pruning)\] | |
+| Distillation | Examples of distilling a pruned or quantized model | \[[Link](#distillation)\] | |
+| Quantization | Examples of quantizing a model | \[[Link](#quantization)\] | |
+| Resources | Extra links to relevant resources | \[[Link](#resources)\] | |
+
+
+
+## Pre-Requisites
+
+Running these examples requires several additional dependencies (e.g., Megatron-Bridge, Megatron-Core), so we strongly recommend using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`), which has all the dependencies installed.
+
+To get the latest ModelOpt features and examples, you can mount your ModelOpt repository clone into the container at `/opt/Model-Optimizer` or pull the latest changes once inside the container (`cd /opt/Model-Optimizer && git checkout main && git pull`).
+
+## Pruning
+
+This section shows how to prune a Hugging Face model using the Minitron algorithm in the Megatron-Bridge framework. Check out other available pruning algorithms, supported frameworks and models, and a general pruning getting-started guide in the [pruning README](../pruning/README.md).
+
+Example usage to prune Qwen3-8B to 6B on 2 GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads`, using the following defaults:
+  1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration,
+  at most 20% depth (`num_layers`) and 40% width are pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...),
+  the top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
+
+```bash
+torchrun --nproc_per_node 2 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
+    --hf_model_name_or_path Qwen/Qwen3-8B \
+    --prune_target_params 6e9 \
+    --hparams_to_skip num_attention_heads \
+    --output_hf_path /tmp/Qwen3-8B-Pruned-6B
+```
+
+To see the full usage for advanced configurations, run:
+
+```bash
+torchrun /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
+```
+
+> [!TIP]
+> If the number of layers in the model is not divisible by the number of GPUs, i.e., the pipeline parallel (PP) size, you can configure
+> uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`.
+> E.g., for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get a 3-5-5-5-5-5-5-3 split of layers across the GPUs.
+
+## Distillation
+
+TODO
+
+## Quantization
+
+TODO
+
+## Resources
+
+## Resources
+
+- 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146)
+- 📖 [Documentation](https://nvidia.github.io/Model-Optimizer)
+- 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html)
+- 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md)
+- ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md)
diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py
new file mode 100644
index 000000000..ab605eacb
--- /dev/null
+++ b/examples/megatron_bridge/prune_minitron.py
@@ -0,0 +1,372 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Example script for pruning a GPT / Mamba model using Minitron algorithm on a Megatron-Bridge model (load from HF). + +Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) +while skipping pruning of num_attention_heads using following defaults: + 1024 samples from nemotron-post-training-dataset-v2 for calibration, + at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...), + top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. + + torchrun --nproc_per_node 2 prune_minitron.py \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_target_params 6e9 \ + --hparams_to_skip num_attention_heads \ + --output_hf_path /tmp/Qwen3-8B-Pruned-6B + +To see the full usage for advanced configurations, run: + torchrun prune_minitron.py --help +""" + +import argparse +import json +import os + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider +from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider +from transformers import AutoConfig, AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.prune as mtp +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import get_supported_datasets, num2hrb, print_rank_0, warn_rank_0 +from modelopt.torch.utils.plugins.mbridge import ( + get_hf_mbridge_calibration_loop, + load_mbridge_model_from_hf, +) +from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--hf_model_name_or_path", type=str, required=True) + parser.add_argument("--trust_remote_code", action="store_true") + + target_group = parser.add_mutually_exclusive_group(required=True) + target_group.add_argument( + "--output_megatron_path", + type=str, + help="Path to save the pruned model in Megatron checkpoint format", + ) + target_group.add_argument( + "--output_hf_path", type=str, help="Path to save the pruned model in HF checkpoint format" + ) + + # Uneven Pipeline Parallelism parameters + parser.add_argument("--num_layers_in_first_pipeline_stage", type=int, default=None) + parser.add_argument("--num_layers_in_last_pipeline_stage", type=int, default=None) + + # Calibration dataset parameters + parser.add_argument( + "--calib_dataset_name", + type=str, + default="nemotron-post-training-dataset-v2", + choices=get_supported_datasets(), + help="Dataset name for calibration", + ) + parser.add_argument( + "--calib_num_samples", type=int, default=1024, help="Number of samples for calibration" + ) + # TODO: Check if mbs>1 is correct or not (because of padding) + parser.add_argument("--calib_mbs", type=int, default=1, help="Calibration micro-batch size") + parser.add_argument("--seq_length", type=int, default=4096) + + # Pruning parameters + parser.add_argument( + "--prune_intermediate_checkpoint", + type=str, + default=None, + help=( + "Path to save/restore intermediate pruning scores for resuming / faster re-run. 
" + "If not provided, it will default to `/modelopt_pruning_scores.pth`" + ), + ) + + target_group = parser.add_mutually_exclusive_group(required=True) + target_group.add_argument( + "--prune_export_config", + type=str, + help=( + 'Target pruned config as JSON e.g., \'{"hidden_size": 512, "ffn_hidden_size": 2048}\'. ' + f"Supported hyperparameters: {mtp.mcore_minitron.SUPPORTED_HPARAMS}. " + "Cannot be used with --prune_target_params." + ), + ) + target_group.add_argument( + "--prune_target_params", + type=float, + help=( + "Target parameter count for pruning e.g., 6e9 for pruning to 6B params (total params, not active params). " + "Uses Neural Architecture Search (NAS) to find the best pruned model that maximizes the --prune_score_func." + "Cannot be used with --prune_export_config." + ), + ) + + parser.add_argument( + "--prune_score_func", + type=str, + choices=["mmlu_5pct"], + default="mmlu_5pct", + help=( + "Score function to use for NAS-based pruning (--prune_target_params). Currently supported: " + "mmlu_5pct (MMLU on 5% sampled data per subject for faster eval). " + ), + ) + parser.add_argument( + "--max_width_pruning", + type=float, + default=0.4, + help=( + f"Maximum width pruning percentage ({mtp.mcore_minitron.SUPPORTED_HPARAMS - {'num_layers'}}) " + "for NAS-based pruning (--prune_target_params)" + ), + ) + parser.add_argument( + "--max_depth_pruning", + type=float, + default=0.2, + help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning (--prune_target_params)", + ) + parser.add_argument( + "--hparams_to_skip", + nargs="*", + type=str, + default=[], + choices=mtp.mcore_minitron.SUPPORTED_HPARAMS, + help=( + "Space-separated list of hparams to skip for NAS-based pruning (--prune_target_params) " + "e.g. dont prune 'num_attention_heads'" + ), + ) + parser.add_argument( + "--top_k", + type=int, + default=10, + help=( + "Number of top candidates to consider for NAS-based pruning (--prune_target_params). " + "Higher values will take longer to prune but may find a better model." + ), + ) + + args = parser.parse_args() + + if args.prune_intermediate_checkpoint is None: + if args.output_megatron_path: + args.prune_intermediate_checkpoint = ( + f"{args.output_megatron_path}/modelopt_pruning_scores.pth" + ) + elif args.output_hf_path: + args.prune_intermediate_checkpoint = ( + f"{args.output_hf_path}/modelopt_pruning_scores.pth" + ) + print_rank_0( + "No checkpoint provided to cache intermediate pruning scores. " + f"Setting to: {args.prune_intermediate_checkpoint}" + ) + + if args.prune_export_config: + try: + prune_export_config = json.loads(args.prune_export_config) + except json.JSONDecodeError as exc: + raise ValueError( + f"Invalid JSON for --prune_export_config: {args.prune_export_config}" + ) from exc + if not isinstance(prune_export_config, dict): + raise ValueError("--prune_export_config must parse to a dictionary.") + args.prune_export_config = prune_export_config + + print_rank_0("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(): + args = get_args() + pp_size = dist.size() + print_rank_0(f"Setting pipeline_model_parallel_size to {pp_size}") + + if os.path.exists(f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"): + warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. 
Exiting...") + return + elif os.path.exists(f"{args.output_hf_path}/config.json"): + warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...") + return + + bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf( + hf_model_name_or_path=args.hf_model_name_or_path, + trust_remote_code=args.trust_remote_code, + provider_overrides={ + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": pp_size, + "num_layers_in_first_pipeline_stage": args.num_layers_in_first_pipeline_stage, + "num_layers_in_last_pipeline_stage": args.num_layers_in_last_pipeline_stage, + "pipeline_dtype": torch.bfloat16, + "seq_length": args.seq_length, + }, + init_model_parallel=True, + ) + print_rank_0(f"\nPruning {unwrapped_model=}") + print_rank_0( + f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" + ) + + forward_loop = get_hf_mbridge_calibration_loop( + model=model, + provider=provider, + hf_model_name_or_path=args.hf_model_name_or_path, + trust_remote_code=args.trust_remote_code, + dataset_name=args.calib_dataset_name, + num_samples=args.calib_num_samples, + micro_batch_size=args.calib_mbs, + ) + + pruning_config = { + "forward_loop": forward_loop, + "checkpoint": args.prune_intermediate_checkpoint, + } + if args.prune_target_params is not None: + # Restrict search space to a smaller set of candidates + # NOTE: You can reduce the divisors and increase config['top_k'] to potentially find a better model. + ss_config = mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=256, + ffn_hidden_size_divisor=512, + mamba_head_dim_divisor=8, + num_moe_experts_divisor=8, + num_layers_divisor=2, + ) + + pruning_constraints = {"params": args.prune_target_params} + print_rank_0( + f"Using NAS-based automatic pruning with score function: {args.prune_score_func}" + "You can change this to be any other metric you want to maximize (e.g. negative validation loss)." 
+ ) + + def score_func_mmlu(m): + return megatron_mmlu(m, tokenizer, percentage=0.05) + + pruning_config["score_func"] = score_func_mmlu + pruning_config["max_width_pruning"] = args.max_width_pruning + pruning_config["max_depth_pruning"] = args.max_depth_pruning + pruning_config["hparams_to_skip"] = args.hparams_to_skip + pruning_config["top_k"] = args.top_k + elif args.prune_export_config is not None: + # Less restrictive search space for manual pruning + ss_config = mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=64, + ffn_hidden_size_divisor=64, + mamba_head_dim_divisor=8, + num_moe_experts_divisor=8, + num_layers_divisor=1, + ) + + pruning_constraints = {"export_config": args.prune_export_config} + print_rank_0(f"Pruning constraints: {pruning_constraints}") + + unwrapped_model, pruning_scores = mtp.prune( # in-place pruning + unwrapped_model, + mode=[("mcore_minitron", ss_config)], # type: ignore[arg-type] + constraints=pruning_constraints, + dummy_input=None, + config=pruning_config, + ) + # Remove unnecessary modelopt_state since ckpt is homogeneous + if mto.ModeloptStateManager.has_state_for_mode_type("prune", model=unwrapped_model): + mto.ModeloptStateManager.remove_state(unwrapped_model) + if isinstance(provider, MambaModelProvider): + provider.hybrid_override_pattern = unwrapped_model.hybrid_override_pattern + print_rank_0(f"\nPruned {unwrapped_model=}") + print_rank_0( + f"Pruned model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" + ) + + if args.output_megatron_path is not None: + print_rank_0( + f"Saved pruned model to {args.output_megatron_path} in Megatron checkpoint format" + ) + + # NOTE: Issue with NemotronH tokenizer's len() hence using use_fast=True as a WAR + use_fast_tokenizer = isinstance(provider, NemotronHModelProvider) + bridge.save_megatron_model( + model, + args.output_megatron_path, + hf_tokenizer_path=args.hf_model_name_or_path, + hf_tokenizer_kwargs={ + "trust_remote_code": args.trust_remote_code, + "use_fast": use_fast_tokenizer, + }, + ) + print_rank_0( + f"Saved pruned model to {args.output_megatron_path} in Megatron checkpoint format" + ) + else: + print_rank_0(f"Saving pruned model to {args.output_hf_path} in HF checkpoint format") + + # [WAR] Hacky way to save pruned HF model until Megatron-Bridge natively supports it + bridge.hf_pretrained.save_artifacts(args.output_hf_path) + hf_cfg = AutoConfig.from_pretrained( + args.output_hf_path, trust_remote_code=args.trust_remote_code + ) + mcore_cfg = unwrapped_model.config + + hf_cfg.hidden_size = mcore_cfg.hidden_size + hf_cfg.intermediate_size = mcore_cfg.ffn_hidden_size + hf_cfg.num_attention_heads = mcore_cfg.num_attention_heads + hf_cfg.head_dim = mcore_cfg.kv_channels + hf_cfg.num_key_value_heads = mcore_cfg.num_query_groups + if hasattr(hf_cfg, "mamba_num_heads"): + hf_cfg.mamba_num_heads = mcore_cfg.mamba_num_heads + if hasattr(hf_cfg, "mamba_head_dim"): + hf_cfg.mamba_head_dim = mcore_cfg.mamba_head_dim + if hasattr(hf_cfg, "moe_intermediate_size"): + hf_cfg.moe_intermediate_size = mcore_cfg.moe_ffn_hidden_size + if hasattr(hf_cfg, "moe_shared_expert_intermediate_size"): + hf_cfg.moe_shared_expert_intermediate_size = ( + mcore_cfg.moe_shared_expert_intermediate_size + ) + if hasattr(hf_cfg, "num_experts"): + hf_cfg.num_experts = mcore_cfg.num_moe_experts + if hasattr(hf_cfg, "n_routed_experts"): + hf_cfg.n_routed_experts = mcore_cfg.num_moe_experts + if hasattr(hf_cfg, "n_shared_experts"): + hf_cfg.n_shared_experts = ( + 
mcore_cfg.moe_shared_expert_intermediate_size // mcore_cfg.moe_ffn_hidden_size + ) + if hasattr(hf_cfg, "layer_types"): + kept_layer_nums = pruning_scores["sorted_layers"][: mcore_cfg.num_layers] # 1-indexed + hf_cfg.layer_types = [ + lt for i, lt in enumerate(hf_cfg.layer_types) if i + 1 in kept_layer_nums + ] + hf_cfg.num_hidden_layers = mcore_cfg.num_layers + + # Save dummy pruned HF model to get the correct bridge for saving pruned weights + AutoModelForCausalLM.from_config(hf_cfg).save_pretrained(args.output_hf_path) + pruned_bridge = AutoBridge.from_hf_pretrained(args.output_hf_path) + pruned_bridge.save_hf_weights(model, args.output_hf_path) + print_rank_0(f"Saved pruned model to {args.output_hf_path} in HF checkpoint format") + + print_rank_0("Done!") + + +if __name__ == "__main__": + dist.setup() + try: + main() + finally: + dist.cleanup() diff --git a/examples/pruning/README.md b/examples/pruning/README.md index d9c17cdc3..95586b3d9 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -6,7 +6,7 @@ Pruning can involve removal (prune) of Linear and Conv layers; and Transformer a This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model: -1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM or NeMo framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model. +1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM (M-LM) or Megatron-Bridge (M-Bridge) framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model. 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints. 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints. @@ -25,7 +25,7 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar ## Pre-Requisites -For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.11`) which has all the dependencies installed. Make sure to upgrade Model Optimizer to the latest version using `pip`. +For Minitron pruning for Megatron-LM / Megatron-Bridge models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed. For FastNAS pruning for PyTorch Computer Vision models, no additional dependencies are required. @@ -39,45 +39,49 @@ To prune your model, you can simply call the `mtp.prune` API and save the pruned ### Minitron -Minitron pruning supports two modes: +Minitron pruning supports two types: 1. 
**Manual Pruning**: Manually specify the target dimensions for each pruning axis (e.g., `constraints = {"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}}`) 2. **NAS-based Auto Pruning (New)**: Specify a target parameter count (e.g., `constraints = {"params": 6e9}`) and let the algorithm automatically search for the best architecture that maximizes a user-defined score function (e.g. MMLU, negative validation loss, etc.) -Please see example snippets of both modes for Minitron pruning on Megatron-Core GPT model below. For end-to-end examples script (M-LM / NeMo framework), please refer to the examples below. +Please see example snippets of both modes for Minitron pruning on Megatron-Bridge Qwen3-8B model below. For end-to-end examples script (M-LM / M-Bridge framework), please refer to the examples below. #### Common Setup ```python +import torch import modelopt.torch.prune as mtp -from megatron.core.models.gpt import GPTModel -from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec -from megatron.core.transformer.transformer_config import TransformerConfig - -# Load the Megatron-Core GPTModel MambaModel with ModelOpt transformer layer spec -model_config = TransformerConfig(...) -model = GPTModel( - config=model_config, - transformer_layer_spec=get_gpt_modelopt_spec(model_config, remap_te_layernorm=True), - ... +from modelopt.torch.utils.plugins.mbridge import ( + get_hf_mbridge_calibration_loop, + load_mbridge_model_from_hf, ) -# Set up the forward loop to run on 512-1024 train samples -# For Megatron-LM framework, you can use the following utility function -from megatron.training.training import evaluate_and_print_results +# Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint +bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf( + hf_model_name_or_path="Qwen/Qwen3-8B", + provider_overrides={ + "pipeline_model_parallel_size": 1, + "pipeline_dtype": torch.bfloat16, + "seq_length": 4096, + }, +) -def forward_loop(_): - evaluate_and_print_results(prefix, forward_step, train_iterator, model, ...) 
+# Set up the forward loop to run on 1024 train samples +forward_loop = get_hf_mbridge_calibration_loop( + model=model, + provider=provider, + hf_model_name_or_path="Qwen/Qwen3-8B", + dataset_name="nemotron-post-training-dataset-v2", + num_samples=1024, +) -# Run the pruning process (if model is a list then pass model[0] to the prune API) -# Save minitron scores at checkpoint so we can re-run pruning with different constraints without running the forward loop again -# NOTE: Skip checkpoint on re-running if you want to change the dataset and re-calibrate -model, pruning_scores = mtp.prune( - model, +# Run pruning on the unwrapped model +mtp.prune( # in-place pruning + unwrapped_model, mode="mcore_minitron", - constraints=constraints, + constraints=constraints, # Shown below for both types dummy_input=None, # Not used - config=config, + config=config, # Shown below for both types ) ``` @@ -90,7 +94,8 @@ This mode can be useful when you know the exact dimensions you want to prune to ```python # Specify the pruning constraints (Check Support Matrix for available pruning dimensions) -constraints = {"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}} +# Save minitron scores at checkpoint so we can re-run pruning with different constraints without running the forward loop again +constraints = {"export_config": {"num_layers": 32, "hidden_size": 3584, "ffn_hidden_size": 10240}} config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"} mtp.prune(...) @@ -119,16 +124,17 @@ def score_func(m): return megatron_mmlu(m, tokenizer, percentage=0.05) # 5% sampled data for faster eval # Specify target parameter count and configure the auto pruning algorithm +# Save minitron scores at checkpoint so we can resume pruning without running the forward loop again constraints = {"params": 6e9} # Prune to 6B parameters config = { "forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth", "score_func": score_func, # Optional: Configure search space constraints (showing defaults) - "max_width_pruning": 0.4, # Maximum 40% per width pruning hparam + "max_width_pruning": 0.4, # Maximum 40% per width pruning hparams (hidden_size, ffn_hidden_size, etc.) "max_depth_pruning": 0.2, # Maximum 20% per depth pruning hparam (num_layers) "hparams_to_skip": [], # Disable pruning specific hparams, e.g., ["num_attention_heads"] - "top_k": 10, # Number of top architectures to evaluate (use 20 for better results at the cost of 2x time) + "top_k": 10, # Number of top architectures to evaluate (using 20 may result in better pruned model at the cost of 2x time) } mtp.prune(...) @@ -160,7 +166,7 @@ ss_config = mtp.mcore_minitron.get_mcore_minitron_config( ) # Use the custom search space config -mtp.prune(model, mode=[("mcore_minitron", ss_config)], ...) +mtp.prune(unwrapped_model, mode=[("mcore_minitron", ss_config)], ...) ``` If your model parameters are already sorted and you just want to prune the weights, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`. 
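As a quick illustration of this sorted-weights path, here is a minimal sketch that re-uses `unwrapped_model` from the common setup above; the `export_config` values are placeholders only:

```python
import modelopt.torch.prune as mtp

# Sketch only: assumes `unwrapped_model` is the already activation-sorted Megatron-Core model
# from the common setup above; the target dimensions below are illustrative placeholders.
config = {"skip_sorting": True}  # weights are already sorted, so no forward_loop is needed

mtp.prune(  # in-place pruning
    unwrapped_model,
    mode="mcore_minitron",
    constraints={"export_config": {"hidden_size": 3584, "ffn_hidden_size": 10240}},
    dummy_input=None,  # Not used
    config=config,
)
```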
@@ -169,20 +175,20 @@ If your model parameters are already sorted and you just want to prune the weigh | **Algorithm** | **Model** | **Pruning Constraints** | | :---: | :---: | :---: | -| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid LLM Models1 | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values
**Auto:** `params` (requires `score_func` in config) | +| Minitron | Megatron-core (M-LM, M-Bridge) based GPT / Mamba / MoE / Hybrid LLM Models1 | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values
**Auto:** `params` (requires `score_func` in config) |
 | FastNAS | Computer Vision models | `flops`, `params` |
 | GradNAS | HuggingFace BERT, GPT-J | `flops`, `params` |
 
-> *1.Only Pipeline Parallel models are supported. Hugging Face models can be converted to Megatron-LM/NeMo format and used subsequently.*
+> *1. Only models with Pipeline Parallelism (PP) are supported. Hugging Face models can be imported into M-Bridge/M-LM format as long as they are [supported](https://docs.nvidia.com/nemo/megatron-bridge/latest/index.html#supported-models) by the framework.*
 
 ## Examples
 
-### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)
+### Minitron Pruning for Megatron-Bridge / Megatron-LM Framework LLMs (e.g. Qwen 3, Nemotron Nano)
 
-Checkout the Minitron pruning example for the [Megatron-LM Framework](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-pruning) or [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama-3.1-8B, Qwen3-8B, Nemotron-Nano-9B-v2, Nemotron-3-Nano-30B-A3B, etc.
+Check out the Minitron pruning example for the [Megatron-Bridge Framework](../megatron_bridge/README.md#pruning) or the [Megatron-LM Framework](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-pruning), which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama-3.1-8B, Qwen3-8B, Nemotron-Nano-9B-v2, Nemotron-3-Nano-30B-A3B, etc. Both frameworks support importing from a Hugging Face pretrained checkpoint.
 
-You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen3-8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
+\[Deprecated\] You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen3-8B step-by-step in the NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
 
 Some of the models pruned using Minitron method followed by distillation and post-training are:
 
diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py
index 7908ec514..636625394 100644
--- a/modelopt/torch/utils/dataset_utils.py
+++ b/modelopt/torch/utils/dataset_utils.py
@@ -97,12 +97,13 @@
 __all__ = [
     "create_forward_loop",
     "get_dataset_dataloader",
+    "get_dataset_samples",
     "get_max_batch_size",
     "get_supported_datasets",
 ]
 
 
-def _get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]:
+def get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]:
     """Load a portion of train dataset with the dataset name and a given size.
Args: @@ -211,7 +212,7 @@ def get_dataset_dataloader( all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): - samples = _get_dataset_samples(ds_name, num_sample) + samples = get_dataset_samples(ds_name, num_sample) all_samples.extend(samples) batch_encoded = tokenizer.batch_encode_plus( diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index b70a2ea6d..7922b6880 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -23,6 +23,7 @@ from contextlib import suppress from datetime import timedelta from typing import Any +from warnings import warn import torch import torch.distributed @@ -76,7 +77,8 @@ def local_rank() -> int: """Returns the local rank of the current process.""" if "LOCAL_RANK" in os.environ: return int(os.environ["LOCAL_RANK"]) - raise RuntimeError("LOCAL_RANK environment variable not found.") + warn("LOCAL_RANK environment variable not found. Using global rank instead.") + return rank() def is_master(group=None) -> bool: diff --git a/modelopt/torch/utils/plugins/__init__.py b/modelopt/torch/utils/plugins/__init__.py index 517c59914..fd00e423f 100644 --- a/modelopt/torch/utils/plugins/__init__.py +++ b/modelopt/torch/utils/plugins/__init__.py @@ -25,3 +25,8 @@ with import_plugin("megatron_preprocess_data"): from .megatron_preprocess_data import * + +# NOTE: Dont pre-import megatron bridge plugin here to avoid circular dependency issues. +# We dont register anything so this isnt a problem. +# with import_plugin("megatron bridge"): +# from .mbridge import * diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py new file mode 100644 index 000000000..59ed7e6e6 --- /dev/null +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Megatron-Bridge plugins for using with Model-Optimizer.""" + +from collections.abc import Callable +from typing import Any + +import torch.nn as nn +from datasets import DatasetDict +from megatron.bridge import AutoBridge +from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig +from megatron.bridge.data.loaders import setup_data_iterators +from megatron.bridge.data.utils import get_dataset_provider +from megatron.bridge.models.gpt_provider import GPTModelProvider, modelopt_transformer_layer_spec +from megatron.bridge.models.hf_pretrained.utils import is_safe_repo +from megatron.bridge.models.mamba.mamba_provider import ( + MambaModelProvider, + modelopt_mamba_stack_spec, +) +from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + OptimizerConfig, + SchedulerConfig, + TrainingConfig, + runtime_config_update, +) +from megatron.bridge.training.eval import evaluate_and_print_results +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.tokenizers.config import TokenizerConfig +from megatron.core.models.gpt import GPTModel +from megatron.core.models.mamba import MambaModel +from megatron.core.transformer.module import MegatronModule +from megatron.core.utils import unwrap_model +from transformers import AutoTokenizer + +from modelopt.torch.utils import get_dataset_samples, print_rank_0 + +__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"] + + +def load_mbridge_model_from_hf( + *, + hf_model_name_or_path: str, + trust_remote_code: bool = False, + provider_overrides: dict[str, Any] | None = None, + init_model_parallel: bool = True, +) -> tuple[ + AutoBridge, + GPTModelProvider | MambaModelProvider, + list[MegatronModule], + GPTModel | MambaModel, + AutoTokenizer, +]: + """Load a Megatron-Bridge model from HF. + + Args: + hf_model_name_or_path: The name or path of the HF model. + trust_remote_code: Whether to trust remote code. + provider_overrides: Overrides for the provider. + init_model_parallel: Whether to initialize model parallel. + + Returns: + A tuple of (bridge, provider, model, unwrapped_model, tokenizer). 
+ """ + print_rank_0(f"Loading Megatron-Bridge model from HF: {hf_model_name_or_path}") + bridge = AutoBridge.from_hf_pretrained( + hf_model_name_or_path, + trust_remote_code=is_safe_repo( + trust_remote_code=trust_remote_code, + hf_path=hf_model_name_or_path, + ), + ) + + provider = bridge.to_megatron_provider() + if provider_overrides: + for key, value in provider_overrides.items(): + assert hasattr(provider, key), f"{type(provider)} does not have attribute {key}" + setattr(provider, key, value) + + print_rank_0("Setting ModelOpt spec for model provider") + if isinstance(provider, MambaModelProvider): + provider.mamba_stack_spec = modelopt_mamba_stack_spec + else: + provider.transformer_layer_spec = modelopt_transformer_layer_spec + + provider.finalize() + if init_model_parallel: + provider.initialize_model_parallel(seed=0) + + model = provider.provide_distributed_model(wrap_with_ddp=False) + assert len(model) == 1 + unwrapped_model = unwrap_model(model[0]) + assert isinstance(unwrapped_model, (GPTModel, MambaModel)) + + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name_or_path, trust_remote_code=trust_remote_code + ) + + return bridge, provider, model, unwrapped_model, tokenizer + + +def _get_dataset_cfg(dataset_name: str, num_samples: int, seq_length: int) -> HFDatasetConfig: + """Get a dataset config for the dataset.""" + dataset = get_dataset_samples(dataset_name, num_samples) + dataset_cfg = HFDatasetConfig( + dataset_name=f"{dataset_name}_{num_samples}", + dataset_dict=DatasetDict({"train": dataset}), + process_example_fn=lambda example, tokenizer: {"input": example, "output": ""}, + seq_length=seq_length, + dataloader_type="batch", + num_workers=1, + do_validation=False, + do_test=False, + val_proportion=None, + split_val_from_train=False, + rewrite=False, + ) + + return dataset_cfg + + +def get_hf_mbridge_calibration_loop( + *, + model: list[MegatronModule], + provider: GPTModelProvider | MambaModelProvider, + hf_model_name_or_path: str, + trust_remote_code: bool = False, + dataset_name: str = "nemotron-post-training-dataset-v2", + num_samples: int = 512, + micro_batch_size: int = 1, +) -> Callable[[nn.Module], None]: + """Get a modelopt calibration loop for a Megatron-Bridge model. + + Args: + model: The model to calibrate. + provider: The provider to use for the model. + hf_model_name_or_path: The name or path of the HF model. + trust_remote_code: Whether to trust remote code. + dataset_name: The name of the dataset to use for evaluation. + num_samples: The number of samples to use for evaluation. + micro_batch_size: The micro batch size to use for evaluation. + + Returns: + A function that can be used to calibrate the model with a modelopt.torch API. 
+ """ + global_batch_size = micro_batch_size + num_iters = num_samples // global_batch_size + + # NOTE: Issue with NemotronH tokenizer's len() hence using use_fast=True as a WAR + use_fast_tokenizer = isinstance(provider, NemotronHModelProvider) + + cfg = ConfigContainer( + model=provider, + train=TrainingConfig( + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + train_iters=num_iters, + eval_iters=num_iters, + skip_train=True, + ), + dataset=_get_dataset_cfg(dataset_name, num_samples, provider.seq_length), + tokenizer=TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model=hf_model_name_or_path, + hf_tokenizer_kwargs={ + "trust_remote_code": trust_remote_code, + "use_fast": use_fast_tokenizer, + }, + ), + # Unused + optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False), + scheduler=SchedulerConfig(lr_decay_style="constant"), + logger=LoggerConfig(), + checkpoint=CheckpointConfig(), + ) + runtime_config_update(cfg) + + state = GlobalState() + state.cfg = cfg + + dataset_provider = get_dataset_provider(cfg.dataset) + + def _train_valid_test_datasets_provider( + train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig + ): + return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer) + + train_data_iterator, _, _ = setup_data_iterators( + cfg=cfg, + train_state=state.train_state, + model_length=len(model), + train_valid_test_datasets_provider=_train_valid_test_datasets_provider, + ) + + def forward_loop(m): + evaluate_and_print_results( + state, + prefix="iteration 1", + forward_step_func=forward_step, + data_iterator=train_data_iterator, + model=model, + config=cfg, + verbose=True, + write_to_tensorboard=False, + ) + + return forward_loop From 0217833e2da5916e08647eb22a9888b691125386 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 22 Jan 2026 11:21:29 -0800 Subject: [PATCH 2/2] Support chat template for calibration samples + minor fixes Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/megatron_bridge/README.md | 4 +-- examples/megatron_bridge/prune_minitron.py | 22 +++++++----- examples/pruning/README.md | 3 +- .../torch/prune/plugins/mcore_minitron.py | 5 ++- modelopt/torch/utils/dataset_utils.py | 36 +++++++++++++++++-- modelopt/torch/utils/plugins/mbridge.py | 27 +++++++++++--- 6 files changed, 78 insertions(+), 19 deletions(-) diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index fcae037c6..2e6fbdff9 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -40,7 +40,7 @@ torchrun --nproc_per_node 2 /opt/Model-Optimizer/examples/megatron_bridge/prune_ To see the full usage for advanced configurations, run: ```bash -torchrun /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help +python /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help ``` > [!TIP] @@ -58,8 +58,6 @@ TODO ## Resources -## Resources - - 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) - 📖 [Documentation](https://nvidia.github.io/Model-Optimizer) - 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index ab605eacb..4a2d279a7 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -27,7 +27,7 @@ 
--output_hf_path /tmp/Qwen3-8B-Pruned-6B To see the full usage for advanced configurations, run: - torchrun prune_minitron.py --help + python prune_minitron.py --help """ import argparse @@ -81,8 +81,11 @@ def get_args() -> argparse.Namespace: parser.add_argument( "--calib_num_samples", type=int, default=1024, help="Number of samples for calibration" ) - # TODO: Check if mbs>1 is correct or not (because of padding) - parser.add_argument("--calib_mbs", type=int, default=1, help="Calibration micro-batch size") + # TODO: Add support for pre-training dataset (pre-tokenized) + # TODO: only allow mbs>1 for pretraining dataset + parser.add_argument( + "--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size" + ) parser.add_argument("--seq_length", type=int, default=4096) # Pruning parameters @@ -197,15 +200,16 @@ def get_args() -> argparse.Namespace: return args -def main(): - args = get_args() +def main(args: argparse.Namespace): pp_size = dist.size() print_rank_0(f"Setting pipeline_model_parallel_size to {pp_size}") - if os.path.exists(f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"): + if args.output_megatron_path and os.path.exists( + f"{args.output_megatron_path}/latest_checkpointed_iteration.txt" + ): warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. Exiting...") return - elif os.path.exists(f"{args.output_hf_path}/config.json"): + elif args.output_hf_path and os.path.exists(f"{args.output_hf_path}/config.json"): warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...") return @@ -230,6 +234,7 @@ def main(): forward_loop = get_hf_mbridge_calibration_loop( model=model, provider=provider, + tokenizer=tokenizer, hf_model_name_or_path=args.hf_model_name_or_path, trust_remote_code=args.trust_remote_code, dataset_name=args.calib_dataset_name, @@ -365,8 +370,9 @@ def score_func_mmlu(m): if __name__ == "__main__": + args = get_args() dist.setup() try: - main() + main(args) finally: dist.cleanup() diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 95586b3d9..41ca6249d 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -70,6 +70,7 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf forward_loop = get_hf_mbridge_calibration_loop( model=model, provider=provider, + tokenizer=tokenizer, hf_model_name_or_path="Qwen/Qwen3-8B", dataset_name="nemotron-post-training-dataset-v2", num_samples=1024, @@ -90,7 +91,7 @@ mtp.prune( # in-place pruning #### 1. Manual Pruning -This mode can be useful when you know the exact dimensions you want to prune to (e.g. fitting a specific latency / memory budget). +This mode can be useful when you know the exact dimensions you want to prune to (e.g. fitting a specific latency / memory budget). Alternatively, you can also use this mode to export top-K architectures (searched using NAS-based auto pruning) and perform short Knowledge Distillation on them before selecting the best architecture. ```python # Specify the pruning constraints (Check Support Matrix for available pruning dimensions) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 97ff1991d..3d476ec8f 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -415,7 +415,10 @@ def search_best_arch_by_params(self) -> dict: ) # 2. 
Perform grid-search over the search space to find subnets fitting the constraints - if max_params not in self.top_k_candidates_per_constraint: + if ( + max_params not in self.top_k_candidates_per_constraint + or len(self.top_k_candidates_per_constraint[max_params]) != top_k + ): max_num_layers = self.model.get_hparam("num_layers").max search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( hp_choices, diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 636625394..b4a0acddd 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -57,6 +57,7 @@ "split": ["stem", "chat", "math", "code"], }, "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["messages"]), + "chat_key": "messages", }, "nemotron-post-training-dataset-v1": { "config": { @@ -64,6 +65,7 @@ "split": ["stem", "chat", "math", "code", "tool_calling"], }, "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["messages"]), + "chat_key": "messages", }, "magpie": { "config": { @@ -71,6 +73,7 @@ "split": ["train"], }, "preprocess": lambda sample: "\n".join(turn["value"] for turn in sample["conversations"]), + "chat_key": "conversations", }, "cnn_dailymail": { "config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]}, @@ -92,6 +95,10 @@ "config": {"path": "c4", "name": "en", "split": ["train"]}, "preprocess": lambda sample: sample["text"], }, + "wikitext": { + "config": {"path": "wikitext", "name": "wikitext-103-v1", "split": ["train"]}, + "preprocess": lambda sample: sample["text"], + }, } __all__ = [ @@ -103,12 +110,21 @@ ] -def get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]: +def get_dataset_samples( + dataset_name: str, + num_samples: int, + *, + apply_chat_template: bool = False, + tokenizer: "PreTrainedTokenizerBase | None" = None, +) -> list[str]: """Load a portion of train dataset with the dataset name and a given size. Args: dataset_name: Name of the dataset to load. num_samples: Number of samples to load from the dataset. + apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). + tokenizer: Tokenizer to use for applying the chat template to the samples. + No tokenization is done and plain text is still returned. Returns: Samples: The list of samples. @@ -123,6 +139,15 @@ def get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]: from datasets import load_dataset dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] + if apply_chat_template: + if "chat_key" not in dataset_config: + warn( + f"Dataset {dataset_name} does not support chat template. Chat template will not be applied." + ) + elif tokenizer is None: + raise ValueError("Tokenizer is required when applying chat template.") + print(f"Applying chat template to dataset {dataset_name}") + # It's unfortunate that the load_dataset function does not support split a list while streaming. # So we need to load the dataset for each split. 
config = dataset_config["config"].copy() @@ -148,7 +173,14 @@ def get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]: break # Apply preprocess function to the sample - samples.append(dataset_config["preprocess"](sample)) + if apply_chat_template and "chat_key" in dataset_config: + sample = tokenizer.apply_chat_template( # type: ignore[union-attr] + sample[dataset_config["chat_key"]], tokenize=False + ) + else: + sample = dataset_config["preprocess"](sample) + if sample != "": # wikitext has some empty samples + samples.append(sample) return samples diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index 59ed7e6e6..09edfb7aa 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -45,6 +45,7 @@ from megatron.bridge.training.tokenizers.config import TokenizerConfig from megatron.core.models.gpt import GPTModel from megatron.core.models.mamba import MambaModel +from megatron.core.parallel_state import get_data_parallel_group from megatron.core.transformer.module import MegatronModule from megatron.core.utils import unwrap_model from transformers import AutoTokenizer @@ -115,9 +116,17 @@ def load_mbridge_model_from_hf( return bridge, provider, model, unwrapped_model, tokenizer -def _get_dataset_cfg(dataset_name: str, num_samples: int, seq_length: int) -> HFDatasetConfig: +def _get_dataset_cfg( + dataset_name: str, + num_samples: int, + seq_length: int, + apply_chat_template: bool = True, + tokenizer: AutoTokenizer | None = None, +) -> HFDatasetConfig: """Get a dataset config for the dataset.""" - dataset = get_dataset_samples(dataset_name, num_samples) + dataset = get_dataset_samples( + dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer + ) dataset_cfg = HFDatasetConfig( dataset_name=f"{dataset_name}_{num_samples}", dataset_dict=DatasetDict({"train": dataset}), @@ -129,7 +138,7 @@ def _get_dataset_cfg(dataset_name: str, num_samples: int, seq_length: int) -> HF do_test=False, val_proportion=None, split_val_from_train=False, - rewrite=False, + rewrite=True, ) return dataset_cfg @@ -139,6 +148,7 @@ def get_hf_mbridge_calibration_loop( *, model: list[MegatronModule], provider: GPTModelProvider | MambaModelProvider, + tokenizer: AutoTokenizer, hf_model_name_or_path: str, trust_remote_code: bool = False, dataset_name: str = "nemotron-post-training-dataset-v2", @@ -150,6 +160,7 @@ def get_hf_mbridge_calibration_loop( Args: model: The model to calibrate. provider: The provider to use for the model. + tokenizer: The tokenizer to use for the model. hf_model_name_or_path: The name or path of the HF model. trust_remote_code: Whether to trust remote code. dataset_name: The name of the dataset to use for evaluation. @@ -159,6 +170,7 @@ def get_hf_mbridge_calibration_loop( Returns: A function that can be used to calibrate the model with a modelopt.torch API. 
""" + # TODO: make global_batch_size larger than micro_batch_size for PP interleaving global_batch_size = micro_batch_size num_iters = num_samples // global_batch_size @@ -174,7 +186,13 @@ def get_hf_mbridge_calibration_loop( eval_iters=num_iters, skip_train=True, ), - dataset=_get_dataset_cfg(dataset_name, num_samples, provider.seq_length), + dataset=_get_dataset_cfg( + dataset_name, + num_samples, + provider.seq_length, + apply_chat_template=True, + tokenizer=tokenizer, + ), tokenizer=TokenizerConfig( tokenizer_type="HuggingFaceTokenizer", tokenizer_model=hf_model_name_or_path, @@ -206,6 +224,7 @@ def _train_valid_test_datasets_provider( train_state=state.train_state, model_length=len(model), train_valid_test_datasets_provider=_train_valid_test_datasets_provider, + dp_group=get_data_parallel_group(), ) def forward_loop(m):