2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -17,6 +17,8 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow.
- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
- Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models.
- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and exporting them as-is.
- Add support for image-text data calibration in PTQ for Nemotron VL models.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
4 changes: 3 additions & 1 deletion examples/llm_ptq/README.md
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| QWen3 MOE, Next <sup>6</sup> | ✅ | - | - | - | ✅ |
| QwQ | ✅ | - | - | - | ✅ |
| DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - | ✅ |
| GLM-4.7<sup>8</sup> | ✅ | - | - | - | ✅ |
| Kimi K2 | - | - | - | - | ✅ |
| T5 | ✅ | ✅ | ✅ | ✅ | - |
| Whisper | ✅ | ❌ | ❌ | ❌ | - |
@@ -121,7 +122,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
> *<sup>4.</sup>For some models, KV cache quantization may result in a higher accuracy penalty.* \
> *<sup>5.</sup>A selective set of the popular models are internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later* \
> *<sup>6.</sup>Some models currently support export to HF format only.* \
> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)*
> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)* \
> *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*

> *The accuracy loss after PTQ varies with the model and the quantization method, and it is usually more significant for smaller base models. If the accuracy after PTQ does not meet your requirements, try modifying [hf_ptq.py](./hf_ptq.py) to disable KV cache quantization, or use [QAT](./../llm_qat/README.md) instead.*

109 changes: 109 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -16,7 +16,9 @@
import copy
import glob
import inspect
import json
import os
import re
import shutil
import sys
import warnings
@@ -27,6 +29,7 @@
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -314,6 +317,106 @@ def get_processor(
return None


def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
"""Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).

Some models store additional layers in separate safetensors files with non-standard
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
files even though they're referenced in model.safetensors.index.json.

This function detects such cases and explicitly loads the missing weights.

Args:
model: The loaded model that may be missing weights
model_path: Path to the model directory

Returns:
List of layer prefixes that were loaded from non-standard safetensors files.
These layers should typically be excluded from quantization.
Empty list if no additional weights were loaded.
"""
model_path = Path(model_path)
index_file = model_path / "model.safetensors.index.json"
mtp_layer_prefixes: list[str] = []

if not index_file.exists():
return mtp_layer_prefixes

# Load the index to find all referenced safetensors files
with open(index_file) as f:
index = json.load(f)

# Find all unique safetensors files referenced
all_files = set(index["weight_map"].values())

# Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
non_standard_files = [f for f in all_files if not standard_pattern.match(f)]

if not non_standard_files:
return mtp_layer_prefixes

# Check which non-standard files exist and have missing weights
model_state = model.state_dict()
total_loaded = 0

for filename in non_standard_files:
filepath = model_path / filename
if not filepath.exists():
continue

# Find keys that should be in this file
expected_keys = [k for k, v in index["weight_map"].items() if v == filename]

# Check which are missing from the model
missing_keys = [k for k in expected_keys if k not in model_state]

if not missing_keys:
# Even if weights are loaded, record the layer prefixes for exclusion
# Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
for key in expected_keys:
# Extract layer prefix like "model.layers.92" or "layers.92"
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
if prefix not in mtp_layer_prefixes:
mtp_layer_prefixes.append(prefix)
break
continue

print(f"Loading {len(missing_keys)} missing weights from {filename}...")

# Extract unique layer prefixes for exclusion from quantization
for key in missing_keys:
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
if prefix not in mtp_layer_prefixes:
mtp_layer_prefixes.append(prefix)
break

        # Load the weights to CPU first; load_state_dict will handle device placement
weights = load_file(str(filepath), device="cpu")
weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

# Load into model
missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
total_loaded += len(weights_to_load)

if missing:
print(f" Warning: {len(missing)} keys still missing after loading {filename}")

if total_loaded > 0:
print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")

if mtp_layer_prefixes:
print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")

return mtp_layer_prefixes


def get_dtype(dtype):
if dtype == "bf16":
dtype = torch.bfloat16
@@ -473,6 +576,12 @@ def get_model(
if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

# Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
# Store the MTP layer prefixes on the model for later exclusion from quantization
mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
if mtp_layer_prefixes:
model._mtp_layer_prefixes = mtp_layer_prefixes

return model


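To make the detection logic in `load_mtp_weights_if_needed` concrete, here is a minimal, self-contained sketch of the index-file handling; the `weight_map` keys, the shard count, and the layer index 92 are illustrative assumptions, not values read from an actual GLM-4.7 checkpoint.

import re

# Hypothetical weight_map entries from model.safetensors.index.json.
weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00046.safetensors",
    "model.layers.92.eh_proj.weight": "mtp.safetensors",
    "model.layers.92.mlp.gate_proj.weight": "mtp.safetensors",
}

# Shards that do not follow the standard model-XXXXX-of-XXXXX naming are treated as extra files.
standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
non_standard_files = {f for f in weight_map.values() if not standard_pattern.match(f)}
assert non_standard_files == {"mtp.safetensors"}

# For every key stored in an extra file, keep the "<...>.layers.<N>" prefix.
mtp_layer_prefixes: list[str] = []
for key, filename in weight_map.items():
    if filename not in non_standard_files:
        continue
    parts = key.split(".")
    for i, part in enumerate(parts):
        if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
            prefix = ".".join(parts[: i + 2])
            if prefix not in mtp_layer_prefixes:
                mtp_layer_prefixes.append(prefix)
            break

assert mtp_layer_prefixes == ["model.layers.92"]
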
20 changes: 20 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -31,6 +31,7 @@
get_tokenizer,
is_enc_dec,
is_nemotron_vl,
load_mtp_weights_if_needed,
run_nemotron_vl_preview,
)
from torch.utils.data import DataLoader
@@ -348,6 +349,12 @@ def load_model(args: argparse.Namespace):
)
calibration_only = True

# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
# Store the MTP layer prefixes on the model for later exclusion from quantization
mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
if mtp_layer_prefixes:
full_model._mtp_layer_prefixes = mtp_layer_prefixes

model_type = get_model_type(full_model)

device = full_model.device
@@ -878,6 +885,19 @@ def quantize_main(
KV_QUANT_CFG_CHOICES,
)

# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
# These layers are typically speculative decoding layers that should be exported as-is
mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None)
if mtp_layer_prefixes:
import copy

quant_cfg = copy.deepcopy(quant_cfg)
for prefix in mtp_layer_prefixes:
# Add exclusion pattern for this MTP layer (e.g., "*layers.92*")
pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
quant_cfg["quant_cfg"][pattern] = {"enable": False}
print(f"Excluding MTP layer from quantization: {pattern}")

if args.qformat in QUANT_CFG_CHOICES:
mono_quantize(
args,
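For illustration, a short sketch of how the wildcard exclusion in `quantize_main` is formed; the detected prefix and the surrounding config entry are assumptions used only to show the shape of the result.

# A detected MTP layer prefix (illustrative value).
prefix = "model.layers.92"
last_two = prefix.split(".")[-2:]            # ["layers", "92"]
pattern = f"*{last_two[0]}.{last_two[1]}*"   # -> "*layers.92*"

# Illustrative config shape; the real quant_cfg comes from QUANT_CFG_CHOICES.
quant_cfg = {"quant_cfg": {"*weight_quantizer": {"num_bits": 4}}}
quant_cfg["quant_cfg"][pattern] = {"enable": False}
# Any quantizer whose module name matches "*layers.92*" is disabled, so the MTP
# layer is skipped during calibration/quantization and kept in its original precision.
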
12 changes: 12 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -694,6 +694,18 @@ def _export_transformers_checkpoint(

quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)

# Add MTP layer prefixes to exclude_modules if they were excluded from quantization
# This ensures they appear in quantization_config["ignore"] in config.json
mtp_layer_prefixes = getattr(model, "_mtp_layer_prefixes", None)
if mtp_layer_prefixes:
exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
for prefix in mtp_layer_prefixes:
# Add wildcard pattern to exclude all submodules under this MTP layer
pattern = f"{prefix}*"
if pattern not in exclude_modules:
exclude_modules.append(pattern)
print(f"Adding MTP layer to quantization_config ignore: {pattern}")

# Process all quantized modules and export weights
_process_quantized_modules(model, dtype, is_modelopt_qlora)

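A sketch of the effect of this export-time change on the quantization config; the pre-existing `exclude_modules` entry and the quant algorithm name are assumed for illustration, not taken from the diff.

# Hypothetical quant_config as returned by get_quant_config.
quant_config = {"quantization": {"quant_algo": "NVFP4", "exclude_modules": ["lm_head"]}}

mtp_layer_prefixes = ["model.layers.92"]
exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
for prefix in mtp_layer_prefixes:
    pattern = f"{prefix}*"
    if pattern not in exclude_modules:
        exclude_modules.append(pattern)

# Per the comment in the diff, these entries surface as quantization_config["ignore"]
# in the exported config.json, so inference runtimes keep the MTP layer unquantized.
assert exclude_modules == ["lm_head", "model.layers.92*"]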