add example for mix-precision recipes (MXFP4+MXFP8) #2289
Merged
Commits (24):
cce026c  add example (xinhe3)
37804c1  fix tp device issue (xinhe3)
5c20e1f  update readme (xinhe3)
01e8266  remove num_fewshot (xinhe3)
e453ac1  change recipe (xinhe3)
b31c359  add llama8b 7bits recipe (xinhe3)
afd41a2  fix gsm8k accuracy issue (xinhe3)
ea9ce69  change batch_size as default 8 (xin3he)
a13eef1  update script (xin3he)
dc5f7cb  update example (xin3he)
458ffab  update for reproducibility (xin3he)
7bed59c  remove print (xin3he)
8600f2f  update recipe of 8b and remove qwen (xin3he)
dd5a4df  Merge branch 'master' into xinhe/mx_recipe (xin3he)
dd1e95b  Update README.md (xin3he)
e1be060  update per review comments (xin3he)
9c1c2d8  update per review suggestions (xinhe3)
f1aab7e  Update quantize.py (xin3he)
2555b01  add save format (xin3he)
155b768  update document for pepeline parallel (xin3he)
d02ec63  warning only (xin3he)
d5095e8  update with Suyue's validation (xin3he)
ef3221f  update requirement (xin3he)
26a0243  update per review comments (xin3he)
File: ...h/nlp/huggingface_models/language-modeling/quantization/mix-precision/README.md (new file, 113 additions, 0 deletions)
# Run

In this example, you can verify accuracy on an HPU or CUDA device with emulation of MXFP4, MXFP8, NVFP4, and uNVFP4.

## Requirement

```bash
# neural-compressor-pt
pip install neural-compressor-pt==3.6

# auto-round
pip install auto-round==0.8.0
```

## Quantization

### Demo (`MXFP4`, `MXFP8`, `NVFP4`, `uNVFP4`)

```bash
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --accuracy
```

### Mix-precision Quantization (`MXFP4 + MXFP8`)

```bash
# Llama 3.1 8B
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
    --quantize \
    --dtype MXFP4 \
    --use_recipe \
    --recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
    --accuracy \
    --batch_size 32

# Llama 3.3 70B
deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
    --quantize \
    --dtype MXFP4 \
    --use_recipe \
    --recipe_file recipes/Meta-Llama-3.3-70B-Instruct_5bits.json \
    --accuracy \
    --batch_size 32
```

> Note:
> 1. Removing `--use_recipe` applies `--dtype` to all blocks in the model.
> 2. Setting `--quant_lm_head` applies `--dtype` to the lm_head layer as well.
> 3. Setting `--iters 0` skips AutoRound tuning and uses the RTN method instead.
> 4. The `deepspeed` usage above provides quick accuracy verification.

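The recipe files under `recipes/` are not shown in this diff. Judging from how `quantize.py` consumes them (the JSON is loaded and assigned to AutoRound's `layer_config`), a recipe appears to map layer names to per-layer quantization settings so that some layers stay in MXFP8 while the rest use the MXFP4 defaults. The sketch below is illustrative only; the layer names and exact key set are assumptions, not the shipped recipes.

```python
# Illustrative only: write a tiny mixed-precision recipe in the layer_config
# style that quantize.py feeds to AutoRound. Layer names are hypothetical;
# real recipes cover the whole model.
import json

recipe = {
    # keep a sensitive projection in MXFP8
    "model.layers.0.self_attn.o_proj": {
        "bits": 8,
        "group_size": 32,
        "data_type": "mx_fp8",
        "act_data_type": "mx_fp8",
    },
    # quantize a less sensitive projection in MXFP4
    "model.layers.0.mlp.gate_proj": {
        "bits": 4,
        "group_size": 32,
        "data_type": "mx_fp4",
        "act_data_type": "mx_fp4",
    },
}

with open("example_recipe.json", "w") as f:
    json.dump(recipe, f, indent=2)
```
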
## Inference usage

### NVFP4
NVFP4 is already supported by vLLM. The model saved in this example follows the `llm_compressor` format; please refer to the public vLLM documentation for usage.

```bash
# Command to save the model:
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype NVFP4 --batch_size 8 --save --save_path opt-125m-nvfp4 --save_format llm_compressor
```

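As a sketch of the vLLM usage referred to above, offline inference through the vLLM Python API could look like the following. The model path is the `--save_path` from the save command; whether a given vLLM build accepts this NVFP4 checkpoint depends on the installed vLLM version.

```python
# Minimal offline-inference sketch with the vLLM Python API.
# "opt-125m-nvfp4" is the --save_path used in the save command above.
from vllm import LLM, SamplingParams

llm = LLM(model="opt-125m-nvfp4")
sampling = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["What is 25 + 37?"], sampling)
print(outputs[0].outputs[0].text)
```
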
### MXFP4 / MXFP8
MXFP4 and MXFP8 are enabled in a forked vLLM repo; usage is shown below:
```bash
# Install the forked vLLM
git clone -b cuda-mxfp8-moe --single-branch --quiet https://github.com/yiliu30/vllm-fork.git && cd vllm-fork
USE_CPP=0 VLLM_USE_PRECOMPILED=1 pip install -e . -vvv && cd -

# Command to save the model:
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
    --quantize \
    --iters 0 \
    --dtype MXFP4 \
    --save_path Llama-3.3-70B-Instruct-MXFP4 \
    --save \
    --save_format llm_compressor

# Command to run inference with vLLM:
CUDA_VISIBLE_DEVICES=0,1 VLLM_USE_V1=0 VLLM_USE_MXFP4_CT_EMULATIONS=1 VLLM_LOGGING_LEVEL=DEBUG \
    vllm serve Llama-3.3-70B-Instruct-MXFP4 --tensor-parallel-size=2 --port 7777 --host localhost --trust-remote-code --dtype bfloat16 --enforce-eager
export no_proxy="localhost, 127.0.0.1, ::1"
curl -X POST http://localhost:7777/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Llama-3.3-70B-Instruct-MXFP4",
        "prompt": "Solve the following math problem step by step: What is 25 + 37? Please answer directly with the result.",
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 1.0
    }'
```

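The same completion request can also be issued from Python with the `requests` library. This is a sketch equivalent to the curl call above, assuming the server started by the `vllm serve` command is listening on localhost:7777.

```python
# Query the OpenAI-compatible /v1/completions endpoint exposed by `vllm serve`.
import requests

resp = requests.post(
    "http://localhost:7777/v1/completions",
    json={
        "model": "Llama-3.3-70B-Instruct-MXFP4",
        "prompt": "Solve the following math problem step by step: "
                  "What is 25 + 37? Please answer directly with the result.",
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 1.0,
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```
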
> Note: To run inference with transformers, please save the model with `--save_format auto_round` and try `python run_hf_inf.py ${model_name_or_path}`.

### MXFP4 + MXFP8
Models with mixed precision are not supported in vLLM, but they are supported in transformers via the `auto-round` format.

```bash
# Command to save the model:
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
    --quantize \
    --iters 0 \
    --dtype MXFP4 \
    --use_recipe \
    --recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
    --save \
    --save_format auto_round \
    --save_path Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR

# Command to run inference with transformers:
python run_hf_inf.py Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR
```

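`run_hf_inf.py` is referenced above but is not included in this diff. A minimal sketch of such a script, assuming the saved auto-round checkpoint loads through the standard transformers API with `auto-round` and `accelerate` installed, might look like:

```python
# Hypothetical run_hf_inf.py-style script: load the quantized checkpoint with
# transformers and generate a short completion. The prompt is arbitrary.
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = sys.argv[1]
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("What is 25 + 37?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
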
File: ...i/pytorch/nlp/huggingface_models/language-modeling/quantization/mix-precision/quantize.py (new file, 258 additions, 0 deletions)
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
import transformers

# For reproducibility
torch.manual_seed(42)
torch.use_deterministic_algorithms(True, warn_only=True)
######################## HPU Memory Optimization ###########################
# Ensure that unnecessary memory is released during quantization.
os.environ.setdefault("PT_HPU_LAZY_MODE", "1")
os.environ.setdefault("PT_HPU_WEIGHT_SHARING", "0")
if int(os.getenv("WORLD_SIZE", "0")) > 0:
    os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
    os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")
from neural_compressor.torch.utils import is_hpex_available, world_size
from auto_round import AutoRound

if is_hpex_available():
    import habana_frameworks.torch.core as htcore
    from habana_frameworks.torch.hpu import wrap_in_hpu_graph

    htcore.hpu_set_env()
############################################################################


def initialize_model_and_tokenizer(model_name_or_path):
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
    config = transformers.AutoConfig.from_pretrained(model_name_or_path)
    # use memory mapping with torch_dtype=config.torch_dtype
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=config.torch_dtype)
    # shard the model across multiple cards with DeepSpeed tensor parallelism
    if world_size > 1:
        ds_inference_kwargs = {
            "dtype": config.torch_dtype,
            "tensor_parallel": {"tp_size": world_size},
        }
        import deepspeed

        ds_model = deepspeed.init_inference(model, **ds_inference_kwargs)
        model = ds_model.module
    model.eval()
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="MX/NV FP4/FP8 quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="model name or path"
    )
    parser.add_argument("--dtype", type=str, default="MXFP4", choices=["MXFP4", "MXFP8", "NVFP4", "NVFP4+", "uNVFP4"], help="data type")
    parser.add_argument("--quantize", action="store_true", help="whether to quantize model")
    parser.add_argument("--device_map", type=str, default=None, help="device map for model")
    parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
    parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
    parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
    parser.add_argument("--seqlen", default=2048, type=int, help="sequence length for autoround.")
    parser.add_argument("--nsamples", default=128, type=int, help="number of samples for autoround.")
    parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
    parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
    parser.add_argument("--save_format", type=str, default="auto_round", help="format to save the quantized model")
    parser.add_argument("--quant_lm_head", action="store_true", help="whether to quantize lm_head")
    parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
    parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.")
    parser.add_argument("--batch_size", default=32, type=int, help="batch size for accuracy evaluation.")
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",  # accept one or more task names as a list
        default=[
            "piqa",
            "hellaswag",
            "mmlu",
            "winogrande",
            "lambada_openai",
        ],  # default task list
        help="tasks for accuracy validation, text-generation and code-generation tasks are different.",
    )
    parser.add_argument("--limit", type=int, default=None, help="number of samples for accuracy evaluation")
    args = parser.parse_args()

    print("Target data type:", args.dtype)

    model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path)
    device = "hpu" if is_hpex_available() else "cuda"

    if args.quantize:
        autoround_dtype_mapping = {
            "MXFP4": "mx_fp4",
            "MXFP8": "mx_fp8",
            "NVFP4": "nv_fp4",
            "uNVFP4": "fp4_v2",
            "NVFP4+": "fp4_v2",
        }
        args.dtype = autoround_dtype_mapping[args.dtype]
        if args.quant_lm_head:
            lm_head_config = {
                "group_size": 32 if "mx" in args.dtype else 16,
                "data_type": args.dtype,
                "act_data_type": "fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
            }
            layer_config = {"lm_head": lm_head_config}

        autoround = AutoRound(
            model,
            tokenizer,
            device=device,
            device_map="tp" if world_size > 1 else args.device_map,
            iters=args.iters,
            seqlen=args.seqlen,
            nsamples=args.nsamples,
            low_gpu_mem_usage=True,
            group_size=32 if "mx" in args.dtype else 16,
            data_type=args.dtype,
            act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
            layer_config=layer_config if args.quant_lm_head else None,
        )

        if args.use_recipe:
            ############ load recipe results (MXFP4 + MXFP8) ############
            def load_recipe_results(file_path):
                import json

                with open(file_path, "r") as f:
                    return json.load(f)

            layer_config = load_recipe_results(args.recipe_file)
            if args.quant_lm_head:
                mxfp8_config = {
                    "bits": 8,
                    "group_size": 32,
                    "data_type": "mx_fp8",
                    "act_data_type": "mx_fp8",
                }
                # ensure lm_head is quantized with mxfp8_config
                layer_config.update({"lm_head": mxfp8_config})
                print("In recipe mode, lm_head is quantized with MXFP8.")
            autoround.layer_config = layer_config

        autoround.quantize()
        model = autoround.model

    if args.accuracy:
        # set dtype to BF16 for HPU inference performance
        model = model.to(torch.bfloat16)
        model = model.eval().to(device)
        if is_hpex_available():
            # HPU needs padding to buckets for better performance.
            # Generation tasks, such as gsm8k and mmlu-pro, may get OOM.
            model = wrap_in_hpu_graph(model)
            htcore.hpu_inference_initialize(model, mark_only_scales_as_const=True)
            from neural_compressor.evaluation.lm_eval import LMEvalParser, evaluate

            tasks = ",".join(args.tasks)
            eval_args = LMEvalParser(
                model="hf",
                user_model=model,
                tokenizer=tokenizer,
                batch_size=args.batch_size,
                tasks=tasks,
                device="hpu",
                pad_to_buckets=True,
                limit=args.limit,
                add_bos_token=True,
            )
            results = evaluate(eval_args)
            torch.hpu.synchronize()
            all_accuracy = {}
            for task_name, task_results in results["results"].items():
                if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]:
                    accu = task_results["acc,none"]
                    all_accuracy[task_name] = accu
                    print(f"Accuracy for {task_name}: {accu:.4f}")
            print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")
        else:
            # CUDA evaluation supports all tasks.
            # gsm8k requires add_bos_token=False for better accuracy on llama models.
            # model = torch.compile(model)
            args.tasks = ["piqa", "hellaswag", "mmlu", "gsm8k"]
            all_accuracy = {}
            test_gsm8k = False
            test_normal = False
            if "gsm8k" in args.tasks:
                test_gsm8k = True
                args.tasks.remove("gsm8k")
            if args.tasks:
                test_normal = True
            import lm_eval
            from lm_eval.models.huggingface import HFLM

            ########################## gsm8k (ahead of normal tasks) #########################
            if test_gsm8k:
                lm = HFLM(
                    pretrained=model,
                    tokenizer=tokenizer,
                    add_bos_token=False,
                    batch_size=args.batch_size,
                )
                results_gsm8k = lm_eval.simple_evaluate(
                    lm,
                    tasks=["gsm8k"],
                    limit=args.limit,
                )
                for task_name, task_results in results_gsm8k["results"].items():
                    accu = task_results["exact_match,strict-match"]
                    all_accuracy[task_name] = accu
            ########################## gsm8k end #########################
            if test_normal:
                lm = HFLM(
                    pretrained=model,
                    tokenizer=tokenizer,
                    add_bos_token=True,
                    batch_size=args.batch_size,
                )
                results = lm_eval.simple_evaluate(
                    lm,
                    tasks=args.tasks,
                    limit=args.limit,
                )
                for task_name, task_results in results["results"].items():
                    if task_name in ["hellaswag", "lambada_openai", "piqa", "winogrande", "mmlu"]:
                        accu = task_results["acc,none"]
                        all_accuracy[task_name] = accu
            for task_name, accu in all_accuracy.items():
                print(f"Accuracy for {task_name}: {accu:.4f}")
            print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")

    if args.save:
        if args.dtype == "nv_fp4":
            # use the llm_compressor format to save the nv_fp4 model
            autoround.save_quantized(args.save_path, format=args.save_format)
        else:
            # use the auto_round format to save mx_fp4 and mx_fp8 models
            if world_size > 1:
                print("Suggest saving the model without sharding for a better reload experience.")
                print("Setting `--device_map 0,1,2,3` provides pipeline parallelism instead of DeepSpeed tensor parallelism.")
                # each rank saves its own shard to a rank-specific sub-directory
                output_dir = args.save_path + "/" + str(args.local_rank) + "_" + str(world_size)
                autoround.save_quantized(output_dir, format=args.save_format)
            else:
                autoround.save_quantized(args.save_path, format=args.save_format)
        print(f"Quantized model in {args.save_format} format is saved to {args.save_path}")
