@@ -0,0 +1,113 @@
# Run

In this example, you can verify accuracy on HPU or CUDA devices with emulation of MXFP4, MXFP8, NVFP4, and uNVFP4.

## Requirements

```bash
# neural-compressor-pt
pip install neural-compressor-pt==3.6
# auto-round
pip install auto-round==0.8.0
```
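
A quick way to confirm the environment is set up (a minimal sketch; it only assumes both packages expose `__version__`):

```python
# Sanity check: both packages should import and report the pinned versions.
import auto_round
import neural_compressor

print("auto-round:", auto_round.__version__)
print("neural-compressor:", neural_compressor.__version__)
```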

## Quantization

### Demo (`MXFP4`, `MXFP8`, `NVFP4`, `uNVFP4`)

```bash
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --accuracy
```
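
Under the hood, `quantize.py` (included in this PR) maps the `--dtype` flag to an AutoRound data type and builds an `AutoRound` object. A minimal sketch of the equivalent API call for the MXFP4 demo (model name and tuning settings mirror the script's defaults and are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# MXFP4 maps to "mx_fp4" in quantize.py; MX formats use a group size of 32.
autoround = AutoRound(
    model,
    tokenizer,
    data_type="mx_fp4",
    act_data_type="mx_fp4",
    group_size=32,
    iters=200,
    seqlen=2048,
    nsamples=128,
)
autoround.quantize()
```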

### Mixed-Precision Quantization (`MXFP4 + MXFP8`)

```bash
# Llama 3.1 8B
python quantize.py \
--model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
--quantize \
--dtype MXFP4 \
--use_recipe \
--recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
--accuracy \
--batch_size 32

# Llama 3.3 70B
deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py \
--model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
--quantize \
--dtype MXFP4 \
--use_recipe \
--recipe_file recipes/Meta-Llama-3.3-70B-Instruct_5bits.json \
--accuracy \
--batch_size 32
```

> Note:
> 1. Without `--use_recipe`, quantization applies `--dtype` to all blocks in the model.
> 2. Setting `--quant_lm_head` also applies `--dtype` to the `lm_head` layer.
> 3. Setting `--iters 0` skips AutoRound tuning and uses the RTN method instead.
> 4. The `deepspeed` launcher is used here for quick accuracy verification across multiple cards.

## Inference usage

### NVFP4
NVFP4 is already supported by vLLM. The model saved in this example follows the `llm_compressor` format; please refer to the public vLLM documentation for usage.

```bash
# Command to save model:
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype NVFP4 --batch_size 8 --save --save_path opt-125m-nvfp4 --save_format llm_compressor
```
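
Once saved in `llm_compressor` format, the checkpoint can be loaded with vLLM's offline API. A minimal sketch (the path `opt-125m-nvfp4` matches the save command above; NVFP4 kernel or emulation support depends on your vLLM version and hardware):

```python
from vllm import LLM, SamplingParams

# vLLM picks up the quantization config stored in the saved checkpoint.
llm = LLM(model="opt-125m-nvfp4")
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```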

### MXFP4 / MXFP8
MXFP4 and MXFP8 are enabled in a forked vLLM repo; usage is shown below:
```bash
# Install the forked vLLM
git clone -b cuda-mxfp8-moe --single-branch --quiet https://github.com/yiliu30/vllm-fork.git && cd vllm-fork
USE_CPP=0 VLLM_USE_PRECOMPILED=1 pip install -e . -vvv && cd -

# Command to save model:
python quantize.py \
--model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
--quantize \
--iters 0 \
--dtype MXFP4 \
--save_path Llama-3.3-70B-Instruct-MXFP4 \
--save \
--save_format llm_compressor

# Command to inference with vLLM:
CUDA_VISIBLE_DEVICES=0,1 VLLM_USE_V1=0 VLLM_USE_MXFP4_CT_EMULATIONS=1 VLLM_LOGGING_LEVEL=DEBUG \
vllm serve Llama-3.3-70B-Instruct-MXFP4 --tensor-parallel-size=2 --port 7777 --host localhost --trust-remote-code --dtype bfloat16 --enforce-eager
export no_proxy="localhost,127.0.0.1,::1"
curl -X POST http://localhost:7777/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/data0/suyue/Llama-3.3-70B-Instruct-MXFP4",
"prompt": "Solve the following math problem step by step: What is 25 + 37? Please answer directly with the result.",
"max_tokens": 100,
"temperature": 0.7,
"top_p": 1.0
}'
```
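
The same endpoint can also be queried with the OpenAI Python client instead of `curl` (a sketch; the port and model name follow the serve command above):

```python
from openai import OpenAI

# The served model name matches the directory passed to `vllm serve`.
client = OpenAI(base_url="http://localhost:7777/v1", api_key="EMPTY")
resp = client.completions.create(
    model="Llama-3.3-70B-Instruct-MXFP4",
    prompt="Solve the following math problem step by step: What is 25 + 37?",
    max_tokens=100,
    temperature=0.7,
)
print(resp.choices[0].text)
```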
> Note: To run inference with transformers, save the model with `--save_format auto_round` and try `python run_hf_inf.py ${model_name_or_path}`.
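
`run_hf_inf.py` is not shown in this PR; below is a minimal sketch of the standard transformers flow it presumably wraps (loading an `auto_round`-format checkpoint requires `auto-round` to be installed):

```python
import sys

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = sys.argv[1]  # e.g. a model saved with --save_format auto_round
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("What is 25 + 37?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```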

### MXFP4 + MXFP8
Mixed-precision models are not supported in vLLM, but they are supported in transformers via the `auto_round` format.

```bash
# Command to save model:
python quantize.py \
--model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
--quantize \
--iters 0 \
--dtype MXFP4 \
--use_recipe \
--recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
--save \
--save_format auto_round \
--save_path Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR

# Command to run inference with transformers:
python run_hf_inf.py Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR
```
@@ -0,0 +1,258 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
import transformers

# For reproducibility
torch.manual_seed(42)
torch.use_deterministic_algorithms(True, warn_only=True)
######################## HPU Memory Optimization ###########################
# ensure that unnecessary memory is released during quantization.
os.environ.setdefault("PT_HPU_LAZY_MODE", "1")
os.environ.setdefault("PT_HPU_WEIGHT_SHARING", "0")
if int(os.getenv("WORLD_SIZE", "0")) > 0:
os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")
from neural_compressor.torch.utils import is_hpex_available, world_size
from auto_round import AutoRound

if is_hpex_available():
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

htcore.hpu_set_env()
############################################################################


def initialize_model_and_tokenizer(model_name_or_path):
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
config = transformers.AutoConfig.from_pretrained(model_name_or_path)
# using memory mapping with torch_dtype=config.torch_dtype
model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=config.torch_dtype)
# shard model for multi-cards and enable hpu graph

if world_size > 1:
ds_inference_kwargs = {
"dtype": config.torch_dtype,
"tensor_parallel": {"tp_size": world_size},
}
import deepspeed

ds_model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = ds_model.module
model.eval()
return model, tokenizer


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Habana FP8 quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="model name or path"
)
parser.add_argument("--dtype", type=str, default="MXFP4", choices=["MXFP4", "MXFP8", "NVFP4", "NVFP4+", "uNVFP4"], help="data type")
parser.add_argument("--quantize", action="store_true", help="whether to quantize model")
parser.add_argument("--device_map", type=str, default=None, help="device map for model")
parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
parser.add_argument("--seqlen", default=2048, type=int, help="sequence length for autoround.")
parser.add_argument("--nsamples", default=128, type=int, help="number of samples for autoround.")
parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
parser.add_argument("--save_format", type=str, default="auto_round", help="format to save the quantized model")
parser.add_argument("--quant_lm_head", action="store_true", help="whether to quantize lm_head")
parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.")
parser.add_argument("--batch_size", default=32, type=int, help="batch size for accuracy evaluation.")
parser.add_argument(
"--tasks",
type=str,
nargs="+", # 接受一个或多个字符串作为列表
default=[
"piqa",
"hellaswag",
"mmlu",
"winogrande",
"lambada_openai",
], # 默认值
help="tasks for accuracy validation, text-generation and code-generation tasks are different.",
)
parser.add_argument("--limit", type=int, default=None, help="number of samples for accuracy evaluation")
args = parser.parse_args()

print("Target data type:", args.dtype)

model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path)
device="hpu" if is_hpex_available() else "cuda"

if args.quantize:
autoround_dtype_mapping = {
"MXFP4": "mx_fp4",
"MXFP8": "mx_fp8",
"NVFP4": "nv_fp4",
"uNVFP4": "fp4_v2",
"NVFP4+": "fp4_v2",
}
args.dtype = autoround_dtype_mapping[args.dtype]
if args.quant_lm_head:
lm_head_config = {
"group_size": 32 if "mx" in args.dtype else 16,
"data_type": args.dtype,
"act_data_type": "fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
}
layer_config = {"lm_head": lm_head_config}

autoround = AutoRound(
model,
tokenizer,
device=device,
device_map="tp" if world_size > 1 else args.device_map,
iters=args.iters,
seqlen=args.seqlen,
nsamples=args.nsamples,
low_gpu_mem_usage=True,
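# group_size: MX formats use 32-element scaling blocks, while NVFP4/uNVFP4 use 16-element blocks.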
group_size=32 if "mx" in args.dtype else 16,
data_type=args.dtype,
act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
layer_config=layer_config if args.quant_lm_head else None,
)

if args.use_recipe:
############ load recipe results (MXFP4 + MXFP8) ############
def load_recipe_results(file_path):
import json
with open(file_path, "r") as f:
return json.load(f)

layer_config = load_recipe_results(args.recipe_file)
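# The recipe file is assumed to map layer names to per-layer quantization configs,
# e.g. {"model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 32, "data_type": "mx_fp8", "act_data_type": "mx_fp8"}, ...},
# mirroring the mxfp8_config structure below; layers not listed keep the global settings.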
if args.quant_lm_head:
mxfp8_config = {
"bits": 8,
"group_size": 32,
"data_type": "mx_fp8",
"act_data_type": "mx_fp8",
}
# ensure lm_head is quantized with mxfp8_config
layer_config.update({"lm_head": mxfp8_config})
print("In recipe mode, lm_head is quantized with MXFP8.")
autoround.layer_config = layer_config

autoround.quantize()
model = autoround.model

if args.accuracy:
# set dtype to BF16 for HPU inference performance
model = model.to(torch.bfloat16)
model = model.eval().to(device)
if is_hpex_available():
# HPU needs padding to buckets for better performance
# Generation tasks, such as gsm8k and mmlu-pro, may run out of memory (OOM).
model = wrap_in_hpu_graph(model)
htcore.hpu_inference_initialize(model, mark_only_scales_as_const=True)
from neural_compressor.evaluation.lm_eval import LMEvalParser, evaluate

tasks = ",".join(args.tasks)
eval_args = LMEvalParser(
model="hf",
user_model=model,
tokenizer=tokenizer,
batch_size=args.batch_size,
tasks=tasks,
device="hpu",
pad_to_buckets=True,
limit=args.limit,
add_bos_token=True,
)
results = evaluate(eval_args)
torch.hpu.synchronize()
all_accuracy = {}
for task_name, task_results in results["results"].items():
if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]:
accu = task_results["acc,none"]
all_accuracy[task_name] = accu
print(f"Accuracy for {task_name}: {accu:.4f}")
print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")
else:
# CUDA evaluation supports all tasks.
# gsm8k requires add_bos_token=False for better accuracy with Llama models.
# model = torch.compile(model)
# Note: the task list is fixed here (overriding --tasks) so that gsm8k is included.
args.tasks = ["piqa", "hellaswag", "mmlu", "gsm8k"]
all_accuracy = {}
test_gsm8k = False
test_normal = False
if "gsm8k" in args.tasks:
test_gsm8k = True
args.tasks.remove("gsm8k")
if args.tasks:
test_normal = True
import lm_eval
from lm_eval.models.huggingface import HFLM

########################## gsm8k (ahead of normal tasks) #########################
if test_gsm8k:
lm = HFLM(
pretrained=model,
tokenizer=tokenizer,
add_bos_token=False,
batch_size=args.batch_size,
)
results_gsm8k = lm_eval.simple_evaluate(
lm,
tasks=["gsm8k"],
limit=args.limit,
)
for task_name, task_results in results_gsm8k["results"].items():
accu = task_results["exact_match,strict-match"]
all_accuracy[task_name] = accu
########################## gsm8k end #########################
if test_normal:
lm = HFLM(
pretrained=model,
tokenizer=tokenizer,
add_bos_token=True,
batch_size=args.batch_size,
)
results = lm_eval.simple_evaluate(
lm,
tasks=args.tasks,
limit=args.limit,
)
for task_name, task_results in results["results"].items():
if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]:
accu = task_results["acc,none"]
all_accuracy[task_name] = accu
for task_name, accu in all_accuracy.items():
print(f"Accuracy for {task_name}: {accu:.4f}")
print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")

if args.save:
if args.dtype == "nv_fp4":
# nv_fp4 models are saved in llm_compressor format (pass --save_format llm_compressor)
autoround.save_quantized(args.save_path, format=args.save_format)
else:
# mx_fp4 and mx_fp8 models are saved in auto_round format
if world_size > 1:
print("Suggest saving the model without sharding for a better reload experience.")
print("Setting --device_map 0,1,2,3 provides pipeline parallelism instead of DeepSpeed tensor parallelism.")
output_dir = f"{args.save_path}/{args.local_rank}_{world_size}"
autoround.save_quantized(output_dir, format=args.save_format)
else:
autoround.save_quantized(args.save_path, format=args.save_format)
print(f"Quantized model in {args.save_format} format is saved to {args.save_path}")