
Commit e1be060

update per review comments
Signed-off-by: He, Xin3 <[email protected]>
1 parent: dd1e95b

File tree

  • examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mix-precision

2 files changed: +45, -19 lines

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mix-precision/README.md

Lines changed: 24 additions & 7 deletions
@@ -1,6 +1,6 @@
 # Run
 
-In this examples, you can verify the accuracy on HPU/CUDA device with emulation of MXFP4, MXFP8, NVFP4 and NVFP4+.
+In this examples, you can verify the accuracy on HPU/CUDA device with emulation of MXFP4, MXFP8, NVFP4 and uNVFP4.
 
 ## Requirement
 
@@ -16,33 +16,50 @@ pip install git+https://github.com/intel/auto-round.git@xinhe/llama_tmp
 ### Demo
 
 ```bash
-python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype mx_fp4 --batch_size 8 --accuracy
+python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --accuracy
 ```
 
-> Note: `--dtype` supports `mx_fp4`(MXFP4), `mx_fp8`(MXFP8), `nv_fp4`(NVFP4), `fp4_v2`(NVFP4+)
+> Note: `--dtype` supports `MXFP4`, `MXFP8`, `NVFP4`, `uNVFP4`
 
-## Mix-precision Quantization (MXFP4 + MXFP8, Target bits: 6)
+## Mix-precision Quantization (MXFP4 + MXFP8)
 
 ```bash
 # Llama 3.1 8B
 python quantize.py \
     --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
     --quantize \
-    --dtype mx_fp4 \
+    --dtype MXFP4 \
     --use_recipe \
     --recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
     --accuracy \
     --batch_size 32
 
-
 # Llama 3.3 70B
 deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py \
     --model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
     --quantize \
-    --dtype mx_fp4 \
+    --dtype MXFP4 \
     --use_recipe \
     --recipe_file recipes/Meta-Llama-3.3-70B-Instruct_6bits.json \
     --accuracy \
     --batch_size 32
 ```
 
+## vLLM usage
+NVFP4 is already supported by vLLM. The model saved in this example follows the `llm_compressor` format; please refer to the public vLLM documentation for usage.
+
+MXFP4 is enabled in a forked repo; usage is as below:
+```bash
+# Install the forked vLLM for MXFP4
+
+# Command to save model:
+python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --save --save_path opt-125m-mxfp4
+
+# Command to evaluate with vLLM:
+
+```
+
+> Notes:
+> 1. A model quantized with DeepSpeed tensor parallelism cannot be saved.
+> 2. A model quantized with a recipe cannot be saved.
+
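The evaluation command in the new "vLLM usage" section above is left blank in this commit, so the following is only a rough, non-authoritative sketch. It loads the directory written by `--save --save_path opt-125m-mxfp4` with vLLM's offline `LLM` API and runs a generation smoke test; it assumes a vLLM build that can consume the `llm_compressor`-format checkpoint (the public release for NVFP4, the forked build for MXFP4), and the fork's install step and extra flags are not given here.

```python
# Hypothetical smoke test; not part of this commit.
# Assumes vLLM can load the llm_compressor-format directory saved by
#   python quantize.py ... --save --save_path opt-125m-mxfp4
from vllm import LLM, SamplingParams

llm = LLM(model="opt-125m-mxfp4")  # path produced by the save command above
params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```

For accuracy numbers comparable to the script's `--accuracy` path, lm-evaluation-harness can also be pointed at the same directory through its vLLM backend.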

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mix-precision/quantize.py

Lines changed: 21 additions & 12 deletions
@@ -67,7 +67,7 @@ def initialize_model_and_tokenizer(model_name_or_path):
 parser.add_argument(
     "--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="model name or path"
 )
-parser.add_argument("--dtype", type=str, default="mx_fp4", choices=["mx_fp4", "mx_fp8", "nv_fp2", "fp4_v2"], help="data type")
+parser.add_argument("--dtype", type=str, default="MXFP4", choices=["MXFP4", "MXFP8", "NVFP4", "NVFP4+", "uNVFP4"], help="data type")
 parser.add_argument("--quantize", action="store_true", help="whether to quantize model")
 parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
 parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
@@ -80,13 +80,6 @@ def initialize_model_and_tokenizer(model_name_or_path):
 parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
 parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.")
 parser.add_argument("--batch_size", default=32, type=int, help="batch size for accuracy evaluation.")
-parser.add_argument(
-    "--mxfp8_mod_list",
-    type=str,
-    nargs="*",
-    default=[],  # default value
-    help="List of module names or patterns for MXFP8 quantization.",
-)
 parser.add_argument(
     "--tasks",
     type=str,
@@ -109,6 +102,14 @@ def initialize_model_and_tokenizer(model_name_or_path):
 device="hpu" if is_hpex_available() else "cuda"
 
 if args.quantize:
+    autoround_dtype_mapping = {
+        "MXFP4": "mx_fp4",
+        "MXFP8": "mx_fp8",
+        "NVFP4": "nv_fp4",
+        "uNVFP4": "fp4_v2",
+        "NVFP4+": "fp4_v2",
+    }
+    args.dtype = autoround_dtype_mapping[args.dtype]
     if args.quant_lm_head:
         lm_head_config = {
             "group_size": 32 if "mx" in args.dtype else 16,
@@ -155,11 +156,10 @@ def load_recipe_results(file_path):
     autoround.quantize()
     model = autoround.model
 
-# set dtype to BF16 for HPU inference performance
-model = model.to(torch.bfloat16)
-model = model.eval().to(device)
-
 if args.accuracy:
+    # set dtype to BF16 for HPU inference performance
+    model = model.to(torch.bfloat16)
+    model = model.eval().to(device)
     if is_hpex_available():
         # HPU needs padding to buckets for better performance
         # Generation tasks, such as gsm8k and mmlu-pro, may get OOM.
@@ -240,3 +240,12 @@ def load_recipe_results(file_path):
     for task_name, accu in all_accuracy.items():
         print(f"Accuracy for {task_name}: {accu:.4f}")
     print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")
+
+if args.save:
+    if world_size > 1:
+        assert False, "model quantized with deepspeed tensor parallel is not supported to be saved."
+    elif args.use_recipe:
+        assert False, "model quantized with recipe is not supported to be saved."
+    else:
+        autoround.save_quantized(args.save_path, format="llm_compressor")
+        print(f"Quantized model is saved to {args.save_path}")
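For readers skimming the diff, the snippet below restates the two behaviors this commit adds to `quantize.py` in isolation: user-facing `--dtype` aliases are translated to auto-round's internal names before quantization, and saving is refused for the DeepSpeed tensor-parallel and recipe paths. The helper names (`resolve_dtype`, `can_save`) are illustrative only; the script inlines this logic as shown above.

```python
# Illustrative restatement of the commit's logic; helper names are hypothetical.
AUTOROUND_DTYPE_MAPPING = {
    "MXFP4": "mx_fp4",
    "MXFP8": "mx_fp8",
    "NVFP4": "nv_fp4",
    "uNVFP4": "fp4_v2",  # uNVFP4 and NVFP4+ share the same backend dtype
    "NVFP4+": "fp4_v2",
}

def resolve_dtype(user_dtype: str) -> str:
    """Map a user-facing --dtype choice to auto-round's internal data type."""
    return AUTOROUND_DTYPE_MAPPING[user_dtype]

def can_save(world_size: int, use_recipe: bool) -> bool:
    """Saving is only supported for single-process, non-recipe quantization."""
    return world_size == 1 and not use_recipe

if __name__ == "__main__":
    assert resolve_dtype("MXFP4") == "mx_fp4"
    assert not can_save(world_size=4, use_recipe=False)  # deepspeed tensor parallel
    assert can_save(world_size=1, use_recipe=False)
```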
