
Commit 7f91341

xin3he and xinhe3 authored
add example for mix-precison recipes (MXFP4+MXFP8) (#2289)
* add example
* add llama8b 7bits recipe
* fix gsm8k accuracy issue

---------

Signed-off-by: xinhe3 <[email protected]>
Signed-off-by: He, Xin3 <[email protected]>
Co-authored-by: xinhe3 <[email protected]>
1 parent 7fb4bde commit 7f91341

File tree

6 files changed, +8251 -0 lines changed

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Run

In this example, you can verify the accuracy of MXFP4, MXFP8, NVFP4, and uNVFP4 emulation on HPU/CUDA devices.

## Requirements

```bash
# neural-compressor-pt
pip install neural-compressor-pt==3.6
# auto-round
pip install auto-round==0.8.0
# others
pip install -r requirements.txt
```

## Quantization

### Demo (`MXFP4`, `MXFP8`, `NVFP4`, `uNVFP4`)

```bash
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --accuracy
```

### Mixed-precision Quantization (`MXFP4 + MXFP8`)

```bash
# Llama 3.1 8B
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
    --quantize \
    --dtype MXFP4 \
    --use_recipe \
    --recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
    --accuracy \
    --batch_size 32

# Llama 3.3 70B
deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
    --quantize \
    --dtype MXFP4 \
    --use_recipe \
    --recipe_file recipes/Meta-Llama-3.3-70B-Instruct_5bits.json \
    --accuracy \
    --batch_size 32
```

> Note:
> 1. Removing `--use_recipe` applies `--dtype` to all blocks in the model; a recipe file maps layer names to per-layer quantization configs (see the sketch below).
> 2. Setting `--quant_lm_head` also applies `--dtype` to the `lm_head` layer.
> 3. Setting `--iters 0` skips AutoRound tuning and uses the RTN method instead.
> 4. The `deepspeed` command is intended for quick accuracy verification.
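
A recipe file is what `--recipe_file` points to; `quantize.py` loads it and hands it to AutoRound as a per-layer configuration. The shipped recipes are not reproduced here, but based on how `quantize.py` consumes them, a recipe is a JSON object mapping layer names to per-layer configs. A hypothetical sketch (layer names and bit choices are illustrative only):

```python
# Hypothetical sketch of a mixed-precision recipe file (not one of the shipped recipes).
# quantize.py reads the JSON with json.load() and assigns the resulting dict to
# autoround.layer_config, so keys are layer names and values are per-layer configs.
import json

recipe = {
    # keep an accuracy-sensitive projection in MXFP8 (8 bits, group size 32)
    "model.layers.0.self_attn.q_proj": {
        "bits": 8, "group_size": 32, "data_type": "mx_fp8", "act_data_type": "mx_fp8",
    },
    # quantize a less sensitive projection in MXFP4 (4 bits, group size 32)
    "model.layers.0.mlp.down_proj": {
        "bits": 4, "group_size": 32, "data_type": "mx_fp4", "act_data_type": "mx_fp4",
    },
}

with open("my_custom_recipe.json", "w") as f:
    json.dump(recipe, f, indent=2)
```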

## Inference usage

### NVFP4

NVFP4 is already supported by vLLM. The model saved in this example follows the `llm_compressor` format; please refer to the public vLLM documentation for serving and inference usage.

```bash
# Command to save the model:
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype NVFP4 --batch_size 8 --save --save_path opt-125m-nvfp4 --save_format llm_compressor
```
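
For a quick local check without starting a server, the saved checkpoint can also be loaded through vLLM's offline Python API. This is a minimal sketch, assuming the installed vLLM build accepts the `llm_compressor`-format checkpoint produced above:

```python
# Minimal sketch: offline vLLM inference on the NVFP4 checkpoint saved above.
# Assumes the public vLLM build recognizes the llm_compressor-format model.
from vllm import LLM, SamplingParams

llm = LLM(model="opt-125m-nvfp4", dtype="bfloat16")  # --save_path from the command above
params = SamplingParams(max_tokens=64, temperature=0.7, top_p=1.0)
outputs = llm.generate(["What is 25 + 37?"], params)
print(outputs[0].outputs[0].text)
```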

### MXFP4 / MXFP8

MXFP4 and MXFP8 are enabled in a forked vLLM repo; usage is shown below:

```bash
# Install the forked vLLM
git clone -b cuda-mxfp8-moe --single-branch --quiet https://github.com/yiliu30/vllm-fork.git && cd vllm-fork
USE_CPP=0 VLLM_USE_PRECOMPILED=1 pip install -e . -vvv && cd -

# Command to save the model:
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
    --quantize \
    --iters 0 \
    --dtype MXFP4 \
    --save_path Llama-3.3-70B-Instruct-MXFP4 \
    --save \
    --save_format llm_compressor

# Command to run inference with vLLM:
CUDA_VISIBLE_DEVICES=0,1 VLLM_USE_V1=0 VLLM_USE_MXFP4_CT_EMULATIONS=1 VLLM_LOGGING_LEVEL=DEBUG \
    vllm serve Llama-3.3-70B-Instruct-MXFP4 --tensor-parallel-size=2 --port 7777 --host localhost --trust-remote-code --dtype bfloat16 --enforce-eager
export no_proxy="localhost, 127.0.0.1, ::1"
curl -X POST http://localhost:7777/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Llama-3.3-70B-Instruct-MXFP4",
        "prompt": "Solve the following math problem step by step: What is 25 + 37? Please answer directly with the result.",
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 1.0
    }'
```

> Note: To run inference with transformers, save the model with `--save_format auto_round` and try `python run_hf_inf.py ${model_name_or_path}`.
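
`run_hf_inf.py` is the reference script for transformers inference. As a rough sketch (assuming `auto-round` is installed so the quantized checkpoint loads correctly), the flow looks roughly like the following; the model path is a hypothetical `--save_path` used with `--save_format auto_round`:

```python
# Minimal sketch of transformers inference on an auto_round-format checkpoint.
# run_hf_inf.py in this example is the reference; the path below is hypothetical.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Llama-3.3-70B-Instruct-MXFP4-AR"  # hypothetical save path
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("What is 25 + 37?", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```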

### MXFP4 + MXFP8

Mixed-precision models are not supported in vLLM, but they are supported in transformers via the `auto_round` format.

```bash
# Command to save the model:
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
    --quantize \
    --iters 0 \
    --dtype MXFP4 \
    --use_recipe \
    --recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
    --save \
    --save_format auto_round \
    --save_path Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR

# Command to run inference with transformers:
python run_hf_inf.py Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR
```
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
import transformers

# For reproducibility
torch.manual_seed(42)
torch.use_deterministic_algorithms(True, warn_only=True)
######################## HPU Memory Optimization ###########################
# ensure that unnecessary memory is released during quantization.
os.environ.setdefault("PT_HPU_LAZY_MODE", "1")
os.environ.setdefault("PT_HPU_WEIGHT_SHARING", "0")
if int(os.getenv("WORLD_SIZE", "0")) > 0:
    os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
    os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")
from neural_compressor.torch.utils import is_hpex_available, world_size
from auto_round import AutoRound

if is_hpex_available():
    import habana_frameworks.torch.core as htcore
    from habana_frameworks.torch.hpu import wrap_in_hpu_graph

    htcore.hpu_set_env()
############################################################################


def initialize_model_and_tokenizer(model_name_or_path):
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
    config = transformers.AutoConfig.from_pretrained(model_name_or_path)
    # using memory mapping with torch_dtype=config.torch_dtype
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=config.torch_dtype)
    # shard the model across cards with DeepSpeed tensor parallelism when launched with multiple ranks
    if world_size > 1:
        ds_inference_kwargs = {
            "dtype": config.torch_dtype,
            "tensor_parallel": {"tp_size": world_size},
        }
        import deepspeed

        ds_model = deepspeed.init_inference(model, **ds_inference_kwargs)
        model = ds_model.module
    model.eval()
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="MXFP4/MXFP8/NVFP4 quantization with AutoRound.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="model name or path"
    )
    parser.add_argument("--dtype", type=str, default="MXFP4", choices=["MXFP4", "MXFP8", "NVFP4", "NVFP4+", "uNVFP4"], help="data type")
    parser.add_argument("--quantize", action="store_true", help="whether to quantize model")
    parser.add_argument("--device_map", type=str, default=None, help="device map for model")
    parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
    parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
    parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
    parser.add_argument("--seqlen", default=2048, type=int, help="sequence length for autoround.")
    parser.add_argument("--nsamples", default=128, type=int, help="number of samples for autoround.")
    parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
    parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
    parser.add_argument("--save_format", type=str, default="auto_round", help="format to save the quantized model")
    parser.add_argument("--quant_lm_head", action="store_true", help="whether to quantize lm_head")
    parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
    parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.")
    parser.add_argument("--batch_size", default=32, type=int, help="batch size for accuracy evaluation.")
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=[
            "piqa",
            "hellaswag",
            "mmlu",
            "winogrande",
            "lambada_openai",
        ],
        help="tasks for accuracy validation, text-generation and code-generation tasks are different.",
    )
    parser.add_argument("--limit", type=int, default=None, help="number of samples for accuracy evaluation")
    args = parser.parse_args()

    print("Target data type:", args.dtype)

    model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path)
    device = "hpu" if is_hpex_available() else "cuda"

    if args.quantize:
        autoround_dtype_mapping = {
            "MXFP4": "mx_fp4",
            "MXFP8": "mx_fp8",
            "NVFP4": "nv_fp4",
            "uNVFP4": "fp4_v2",
            "NVFP4+": "fp4_v2",
        }
        args.dtype = autoround_dtype_mapping[args.dtype]
        if args.quant_lm_head:
            lm_head_config = {
                "group_size": 32 if "mx" in args.dtype else 16,
                "data_type": args.dtype,
                "act_data_type": "fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
            }
            layer_config = {"lm_head": lm_head_config}

        autoround = AutoRound(
            model,
            tokenizer,
            device=device,
            device_map="tp" if world_size > 1 else args.device_map,
            iters=args.iters,
            seqlen=args.seqlen,
            nsamples=args.nsamples,
            low_gpu_mem_usage=True,
            group_size=32 if "mx" in args.dtype else 16,
            data_type=args.dtype,
            act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
            layer_config=layer_config if args.quant_lm_head else None,
        )

        if args.use_recipe:
            ############ load recipe results (MXFP4 + MXFP8) ############
            def load_recipe_results(file_path):
                import json

                with open(file_path, "r") as f:
                    return json.load(f)

            layer_config = load_recipe_results(args.recipe_file)
            if args.quant_lm_head:
                mxfp8_config = {
                    "bits": 8,
                    "group_size": 32,
                    "data_type": "mx_fp8",
                    "act_data_type": "mx_fp8",
                }
                # ensure lm_head is quantized with mxfp8_config
                layer_config.update({"lm_head": mxfp8_config})
                print("In recipe mode, lm_head is quantized with MXFP8.")
            autoround.layer_config = layer_config

        autoround.quantize()
        model = autoround.model

    if args.accuracy:
        # set dtype to BF16 for HPU inference performance
        model = model.to(torch.bfloat16)
        model = model.eval().to(device)
        if is_hpex_available():
            # HPU needs padding to buckets for better performance.
            # Generation tasks, such as gsm8k and mmlu-pro, may get OOM.
            model = wrap_in_hpu_graph(model)
            htcore.hpu_inference_initialize(model, mark_only_scales_as_const=True)
            from neural_compressor.evaluation.lm_eval import LMEvalParser, evaluate

            tasks = ",".join(args.tasks)
            eval_args = LMEvalParser(
                model="hf",
                user_model=model,
                tokenizer=tokenizer,
                batch_size=args.batch_size,
                tasks=tasks,
                device="hpu",
                pad_to_buckets=True,
                limit=args.limit,
                add_bos_token=True,
            )
            results = evaluate(eval_args)
            torch.hpu.synchronize()
            all_accuracy = {}
            for task_name, task_results in results["results"].items():
                if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]:
                    accu = task_results["acc,none"]
                    all_accuracy[task_name] = accu
                    print(f"Accuracy for {task_name}: {accu:.4f}")
            print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")
        else:
            # CUDA evaluation supports all tasks.
            # gsm8k requires add_bos_token=False for better accuracy with Llama models.
            # model = torch.compile(model)
            # a fixed task list is evaluated on CUDA (this overrides --tasks)
            args.tasks = ["piqa", "hellaswag", "mmlu", "gsm8k"]
            all_accuracy = {}
            test_gsm8k = False
            test_normal = False
            if "gsm8k" in args.tasks:
                test_gsm8k = True
                args.tasks.remove("gsm8k")
            if args.tasks:
                test_normal = True
            import lm_eval
            from lm_eval.models.huggingface import HFLM

            ########################## gsm8k (ahead of normal tasks) #########################
            if test_gsm8k:
                lm = HFLM(
                    pretrained=model,
                    tokenizer=tokenizer,
                    add_bos_token=False,
                    batch_size=args.batch_size,
                )
                results_gsm8k = lm_eval.simple_evaluate(
                    lm,
                    tasks=["gsm8k"],
                    limit=args.limit,
                )
                for task_name, task_results in results_gsm8k["results"].items():
                    accu = task_results["exact_match,strict-match"]
                    all_accuracy[task_name] = accu
            ########################## gsm8k end #########################
            if test_normal:
                lm = HFLM(
                    pretrained=model,
                    tokenizer=tokenizer,
                    add_bos_token=True,
                    batch_size=args.batch_size,
                )
                results = lm_eval.simple_evaluate(
                    lm,
                    tasks=args.tasks,
                    limit=args.limit,
                )
                for task_name, task_results in results["results"].items():
                    if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]:
                        accu = task_results["acc,none"]
                        all_accuracy[task_name] = accu
            for task_name, accu in all_accuracy.items():
                print(f"Accuracy for {task_name}: {accu:.4f}")
            print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")

    if args.save:
        if args.dtype == "nv_fp4":
            # use llm_compressor format to save the nv_fp4 model
            autoround.save_quantized(args.save_path, format=args.save_format)
        else:
            # use auto_round format to save mx_fp4 and mx_fp8 models
            if world_size > 1:
                print("It is suggested to save the model without sharding for a better reload experience.")
                print("Setting `--device_map 0,1,2,3` provides pipeline parallelism instead of DeepSpeed tensor parallelism.")
                # one shard per rank; world_size comes from neural_compressor.torch.utils
                output_dir = args.save_path + "/" + str(args.local_rank) + "_" + str(world_size)
                autoround.save_quantized(output_dir, format=args.save_format)
            else:
                autoround.save_quantized(args.save_path, format=args.save_format)
        print(f"Quantized model in {args.save_format} format is saved to {args.save_path}")
