From 116de1de0b4068f746bb9040f6a1d74ee6133a98 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Thu, 15 Jan 2026 23:06:11 -0800
Subject: [PATCH 1/9] add support for MTP loading for GLM-4.7 in PTQ flow

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/example_utils.py | 78 +++++++++++++++++++++++++++++++
 examples/llm_ptq/hf_ptq.py        |  5 ++
 2 files changed, 83 insertions(+)

diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index aad29fc97..a90c14f27 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -314,6 +314,81 @@ def get_processor(
     return None
 
 
+def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool:
+    """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
+
+    Some models store additional layers in separate safetensors files with non-standard
+    names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
+    files even though they're referenced in model.safetensors.index.json.
+
+    This function detects such cases and explicitly loads the missing weights.
+
+    Args:
+        model: The loaded model that may be missing weights
+        model_path: Path to the model directory
+
+    Returns:
+        True if additional weights were loaded, False otherwise
+    """
+    model_path = Path(model_path)
+    index_file = model_path / "model.safetensors.index.json"
+
+    if not index_file.exists():
+        return False
+
+    # Load the index to find all referenced safetensors files
+    with open(index_file) as f:
+        index = json.load(f)
+
+    # Find all unique safetensors files referenced
+    all_files = set(index["weight_map"].values())
+
+    # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
+    import re
+    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
+    non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
+
+    if not non_standard_files:
+        return False
+
+    # Check which non-standard files exist and have missing weights
+    model_state = model.state_dict()
+    total_loaded = 0
+
+    for filename in non_standard_files:
+        filepath = model_path / filename
+        if not filepath.exists():
+            continue
+
+        # Find keys that should be in this file
+        expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
+
+        # Check which are missing from the model
+        missing_keys = [k for k in expected_keys if k not in model_state]
+
+        if not missing_keys:
+            continue
+
+        print(f"Loading {len(missing_keys)} missing weights from {filename}...")
+
+        # Load the weights
+        weights = load_file(str(filepath))
+        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
+
+        # Load into model
+        missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
+        total_loaded += len(weights_to_load)
+
+        # With strict=False, `missing` lists every model key absent from this partial
+        # dict, so only warn about the keys this file was expected to provide
+        still_missing = [k for k in missing_keys if k not in weights_to_load]
+        if still_missing:
+            print(f"  Warning: {len(still_missing)} keys still missing after loading {filename}")
+
+    if total_loaded > 0:
+        print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
+        return True
+
+    return False
+
+
 def get_dtype(dtype):
     if dtype == "bf16":
         dtype = torch.bfloat16
@@ -473,6 +548,9 @@ def get_model(
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
 
+    # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
+    load_mtp_weights_if_needed(model, ckpt_path)
+
     return model
 
 
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index e32d0dae8..971cc90b5 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -348,6 +348,11 @@ def load_model(args: argparse.Namespace):
         )
         calibration_only = True
 
+    # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
+    from example_utils import load_mtp_weights_if_needed
+
+    load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
+
     model_type = get_model_type(full_model)
     device = full_model.device
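The shard detection above keys entirely off the file-name convention recorded in
model.safetensors.index.json. A minimal sketch of that check, with an illustrative
weight map rather than keys taken from a real GLM-4.7 checkpoint:

    import re

    weight_map = {
        "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
        "model.layers.1.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
        "model.layers.92.shared_head.head.weight": "mtp.safetensors",  # MTP shard
    }
    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
    non_standard = {f for f in weight_map.values() if not standard_pattern.match(f)}
    print(non_standard)  # {'mtp.safetensors'} -> its tensors get loaded explicitly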
From 01f0b050724386969d0d41faec9325d1d7e4fae3 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Thu, 15 Jan 2026 23:06:44 -0800
Subject: [PATCH 2/9] add support for MTP loading for GLM-4.7 in PTQ flow

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/example_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index a90c14f27..056b6c578 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -345,6 +345,7 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool:
 
     # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
     import re
+
     standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
     non_standard_files = [f for f in all_files if not standard_pattern.match(f)]

From e13d9a2c441f93e614f5e3858f3f445a691f5a9a Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Thu, 15 Jan 2026 23:08:21 -0800
Subject: [PATCH 3/9] add support for MTP loading for GLM-4.7 in PTQ flow

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/example_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 056b6c578..f6bfdab1d 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -16,7 +16,9 @@
 import copy
 import glob
 import inspect
+import json
 import os
+import re
 import shutil
 import sys
 import warnings
@@ -27,6 +29,7 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
+from safetensors.torch import load_file
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -344,7 +347,6 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool:
     all_files = set(index["weight_map"].values())
 
     # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
-    import re
 
     standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
     non_standard_files = [f for f in all_files if not standard_pattern.match(f)]

From 0ad3f88d3a90d48bf2f7f87a4c5371f1e63c3f75 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 3 Feb 2026 15:32:19 -0800
Subject: [PATCH 4/9] minor

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/hf_ptq.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 971cc90b5..ec645556b 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -31,6 +31,7 @@
     get_tokenizer,
     is_enc_dec,
     is_nemotron_vl,
+    load_mtp_weights_if_needed,
     run_nemotron_vl_preview,
 )
 from torch.utils.data import DataLoader
@@ -349,8 +350,6 @@ def load_model(args: argparse.Namespace):
         calibration_only = True
 
     # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
-    from example_utils import load_mtp_weights_if_needed
-
     load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
 
     model_type = get_model_type(full_model)
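With the helper now a module-level import, the wiring reduces to one call after
from_pretrained. A hypothetical stand-alone usage (the checkpoint path and kwargs
are illustrative; AutoModelForCausalLM.from_pretrained is the standard
transformers API):

    from transformers import AutoModelForCausalLM

    from example_utils import load_mtp_weights_if_needed

    ckpt = "/checkpoints/GLM-4.7"  # illustrative path
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto")
    # No-op when from_pretrained already materialized every shard in the index
    load_mtp_weights_if_needed(model, ckpt)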
From 2774a2c5d0512feeb50044679a957856297aa8ab Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 3 Feb 2026 15:51:01 -0800
Subject: [PATCH 5/9] skip MTP layers from quantization and export as-is

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/example_utils.py | 44 +++++++++++++++++++++++++------
 examples/llm_ptq/hf_ptq.py        | 18 ++++++++++++-
 2 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index f6bfdab1d..4c0175961 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -317,7 +317,7 @@ def get_processor(
     return None
 
 
-def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> bool:
+def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
     """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
 
     Some models store additional layers in separate safetensors files with non-standard
@@ -331,13 +331,16 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
         model_path: Path to the model directory
 
     Returns:
-        True if additional weights were loaded, False otherwise
+        List of layer prefixes that were loaded from non-standard safetensors files.
+        These layers should typically be excluded from quantization.
+        Empty list if no additional weights were loaded.
     """
     model_path = Path(model_path)
     index_file = model_path / "model.safetensors.index.json"
+    mtp_layer_prefixes: list[str] = []
 
     if not index_file.exists():
-        return False
+        return mtp_layer_prefixes
 
     # Load the index to find all referenced safetensors files
     with open(index_file) as f:
@@ -347,12 +350,11 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
     all_files = set(index["weight_map"].values())
 
     # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
-
     standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
     non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
 
     if not non_standard_files:
-        return False
+        return mtp_layer_prefixes
 
     # Check which non-standard files exist and have missing weights
     model_state = model.state_dict()
@@ -370,10 +372,31 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
         missing_keys = [k for k in expected_keys if k not in model_state]
 
         if not missing_keys:
+            # Even if weights are loaded, record the layer prefixes for exclusion
+            # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
+            for key in expected_keys:
+                # Extract layer prefix like "model.layers.92" or "layers.92"
+                parts = key.split(".")
+                for i, part in enumerate(parts):
+                    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                        prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                        if prefix not in mtp_layer_prefixes:
+                            mtp_layer_prefixes.append(prefix)
+                        break
             continue
 
         print(f"Loading {len(missing_keys)} missing weights from {filename}...")
 
+        # Extract unique layer prefixes for exclusion from quantization
+        for key in missing_keys:
+            parts = key.split(".")
+            for i, part in enumerate(parts):
+                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                    if prefix not in mtp_layer_prefixes:
+                        mtp_layer_prefixes.append(prefix)
+                    break
+
         # Load the weights
         weights = load_file(str(filepath))
         weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
@@ -387,9 +410,11 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
 
     if total_loaded > 0:
         print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
-        return True
 
-    return False
+    if mtp_layer_prefixes:
+        print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")
+
+    return mtp_layer_prefixes
 
 
 def get_dtype(dtype):
@@ -552,7 +577,10 @@ def get_model(
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
 
     # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
-    load_mtp_weights_if_needed(model, ckpt_path)
+    # Store the MTP layer prefixes on the model for later exclusion from quantization
+    mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
+    if mtp_layer_prefixes:
+        model._mtp_layer_prefixes = mtp_layer_prefixes
 
     return model
 
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index ec645556b..a5af5e97d 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -350,7 +350,10 @@ def load_model(args: argparse.Namespace):
         calibration_only = True
 
     # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
-    load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
+    # Store the MTP layer prefixes on the model for later exclusion from quantization
+    mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
+    if mtp_layer_prefixes:
+        full_model._mtp_layer_prefixes = mtp_layer_prefixes
 
     model_type = get_model_type(full_model)
 
@@ -882,6 +885,19 @@ def quantize_main(
         KV_QUANT_CFG_CHOICES,
     )
 
+    # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
+    # These layers are typically speculative decoding layers that should be exported as-is
+    mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None)
+    if mtp_layer_prefixes:
+        import copy
+
+        quant_cfg = copy.deepcopy(quant_cfg)
+        for prefix in mtp_layer_prefixes:
+            # Add exclusion pattern for this MTP layer (e.g., "*layers.92*")
+            pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
+            quant_cfg["quant_cfg"][pattern] = {"enable": False}
+            print(f"Excluding MTP layer from quantization: {pattern}")
+
     if args.qformat in QUANT_CFG_CHOICES:
         mono_quantize(
             args,
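A worked example of the two string transformations this patch relies on, using an
illustrative MTP tensor key rather than one from a real checkpoint:

    key = "model.layers.92.shared_head.head.weight"  # illustrative key name

    parts = key.split(".")
    prefix = ""
    for i, part in enumerate(parts):
        if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
            prefix = ".".join(parts[: i + 2])
            break
    print(prefix)  # model.layers.92

    # Wildcard written into quant_cfg so every quantizer under the layer is disabled
    pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
    print(pattern)  # *layers.92*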
From 39c6195d103dd5128b4ac75d458a498f4b4f943b Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 3 Feb 2026 15:59:58 -0800
Subject: [PATCH 6/9] add MTP modules in excluded/ignore modules in config

Signed-off-by: Zhiyu Cheng
---
 modelopt/torch/export/unified_export_hf.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 011af533d..2c250fbaa 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -694,6 +694,18 @@ def _export_transformers_checkpoint(
 
     quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)
 
+    # Add MTP layer prefixes to exclude_modules if they were excluded from quantization
+    # This ensures they appear in quantization_config["ignore"] in config.json
+    mtp_layer_prefixes = getattr(model, "_mtp_layer_prefixes", None)
+    if mtp_layer_prefixes:
+        exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
+        for prefix in mtp_layer_prefixes:
+            # Add wildcard pattern to exclude all submodules under this MTP layer
+            pattern = f"{prefix}*"
+            if pattern not in exclude_modules:
+                exclude_modules.append(pattern)
+                print(f"Adding MTP layer to quantization_config ignore: {pattern}")
+
     # Process all quantized modules and export weights
     _process_quantized_modules(model, dtype, is_modelopt_qlora)
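The effect of this hunk, sketched on a hand-built dict (the quant_algo value and
surrounding layout are illustrative; only the "quantization"/"exclude_modules"
access mirrors the patch):

    quant_config = {"quantization": {"quant_algo": "NVFP4", "exclude_modules": ["lm_head"]}}

    for prefix in ["model.layers.92"]:  # as recorded on the model at load time
        pattern = f"{prefix}*"
        exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
        if pattern not in exclude_modules:
            exclude_modules.append(pattern)
    print(quant_config["quantization"]["exclude_modules"])
    # ['lm_head', 'model.layers.92*'] -> serialized into quantization_config["ignore"]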
From f060d94592da02a32d9e75eb82ad6c36d3ee50a3 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 3 Feb 2026 16:18:15 -0800
Subject: [PATCH 7/9] update changelog

Signed-off-by: Zhiyu Cheng
---
 CHANGELOG.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 99932794a..9eccd173e 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -17,6 +17,8 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow.
 - Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
 - Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models.
+- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and exporting them as-is.
+- Add support for image-text data calibration in PTQ for Nemotron VL models.
 
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^

From 5e42017bf2d9ab1f5f4a006470dc7318a845e9e3 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 3 Feb 2026 16:20:56 -0800
Subject: [PATCH 8/9] update readme

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 859551809..9d85c0998 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 | QWen3 MOE, Next ⁶ | ✅ | - | - | - | ✅ |
 | QwQ | ✅ | - | - | - | ✅ |
 | DeepSeek V3, R1, V3.1, V3.2 ⁷ | - | - | - | - | ✅ |
+| GLM-4.7 ⁸ | ✅ | - | - | - | ✅ |
 | Kimi K2 | - | - | - | - | ✅ |
 | T5 | ✅ | ✅ | ✅ | ✅ | - |
 | Whisper | ✅ | ❌ | ❌ | ❌ | - |
@@ -121,7 +122,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 > *4. For some models, KV cache quantization may result in a higher accuracy penalty.* \
 > *5. A selective set of the popular models is internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later.* \
 > *6. Some models currently support export to HF format only.* \
-> *7. [PTQ for DeepSeek](../deepseek/README.md)*
+> *7. [PTQ for DeepSeek](../deepseek/README.md)* \
+> *8. GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*
 
 > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss, and it is usually more significant when the base model is small. If the accuracy after PTQ does not meet the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization, or using [QAT](./../llm_qat/README.md) instead.*
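One way to spot-check the behavior the new footnote describes, assuming an already
exported checkpoint directory (the path and exact field layout are assumptions,
not taken from this series):

    import json

    with open("glm-4.7-exported/config.json") as f:  # illustrative path
        cfg = json.load(f)

    ignore = cfg.get("quantization_config", {}).get("ignore", [])
    # Expect the MTP pattern added at export time, e.g. "model.layers.92*"
    print([p for p in ignore if "layers.92" in p])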
From 3bda6d816b0841baa24ee312454dd3493312d8f6 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Wed, 4 Feb 2026 11:06:19 -0800
Subject: [PATCH 9/9] minor

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/example_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 4c0175961..e8f5575d3 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -397,8 +397,8 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
                     break
 
-        # Load the weights
-        weights = load_file(str(filepath))
+        # Load the weights to CPU first, load_state_dict will handle device placement
+        weights = load_file(str(filepath), device="cpu")
         weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
 
         # Load into model
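For reference on this last change: safetensors' load_file accepts a device
argument, and loading to CPU keeps the extra shard out of GPU memory, because
load_state_dict copies each value into the existing parameter in place, on
whatever device that parameter already lives. A small sketch (the file name is
illustrative):

    from safetensors.torch import load_file

    weights = load_file("mtp.safetensors", device="cpu")  # host-memory tensors
    # model.load_state_dict(weights, strict=False) then copies tensor data into
    # the matching parameters without moving them off their current device.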