Skip to content

Commit d6149aa

Browse files
authored
Update the PT2E CV example (#2032)
Signed-off-by: yiliu30 <[email protected]>
1 parent 08ec908 commit d6149aa

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

examples/3.x_api/pytorch/cv/static_quant/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ This implements quantization of popular model architectures, such as ResNet on t
1414
To quantize a model and validate accuracy, run `main.py` with the desired model architecture and the path to the ImageNet dataset:
1515

1616
```bash
17-
python main.py -a resnet18 [imagenet-folder with train and val folders] -q -e
17+
export ImageNetDataPath=/path/to/imagenet
18+
python main.py $ImageNetDataPath --pretrained -a resnet18 --tune --calib_iters 5
1819
```
1920

2021

examples/3.x_api/pytorch/cv/static_quant/main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
4747
metavar='W', help='weight decay (default: 1e-4)',
4848
dest='weight_decay')
49+
parser.add_argument('--w_granularity', default="per_channel", type=str, choices=["per_channel", "per_tensor"], help='weight granularity')
4950
parser.add_argument('-p', '--print-freq', default=10, type=int,
5051
metavar='N', help='print frequency (default: 10)')
5152
parser.add_argument('--resume', default='', type=str, metavar='PATH',
@@ -179,7 +180,7 @@ def eval_func(model):
179180

180181
if args.tune:
181182
from neural_compressor.torch.export import export
182-
from neural_compressor.torch.quantization import prepare, convert, get_default_static_config
183+
from neural_compressor.torch.quantization import prepare, convert, StaticQuantConfig
183184

184185
# Prepare the float model and example inputs for exporting model
185186
x = torch.randn(args.batch_size, 3, 224, 224).contiguous(memory_format=torch.channels_last)
@@ -188,15 +189,15 @@ def eval_func(model):
188189
# Specify that the first dimension of each input is that batch size
189190
from torch.export import Dim
190191
print(args.batch_size)
191-
batch = Dim("batch", min=16)
192+
batch = Dim("batch")
192193

193194
# Specify that the first dimension of each input is that batch size
194195
dynamic_shapes = {"x": {0: batch}}
195196

196197
# Export eager model into FX graph model
197198
exported_model = export(model=model, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes)
198199
# Quantize the model
199-
quant_config = get_default_static_config()
200+
quant_config = StaticQuantConfig(w_granularity=args.w_granularity)
200201

201202
prepared_model = prepare(exported_model, quant_config=quant_config)
202203
# Calibrate
@@ -233,7 +234,9 @@ def eval_func(model):
233234
new_model = opt_model
234235
else:
235236
new_model = model
237+
# For fair comparison, we also compile the float model
236238
new_model.eval()
239+
new_model = torch.compile(new_model)
237240
if args.performance:
238241
benchmark(val_loader, new_model, args)
239242
return

examples/3.x_api/pytorch/cv/static_quant/run_benchmark.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function run_benchmark {
7979
python main.py \
8080
--pretrained \
8181
-a resnet18 \
82-
-b 30 \
82+
-b ${batch_size} \
8383
--tuned_checkpoint ${tuned_checkpoint} \
8484
${dataset_location} \
8585
${extra_cmd} \
@@ -89,7 +89,7 @@ function run_benchmark {
8989
main.py \
9090
--pretrained \
9191
-a resnet18 \
92-
-b 30 \
92+
-b ${batch_size} \
9393
--tuned_checkpoint ${tuned_checkpoint} \
9494
${dataset_location} \
9595
${extra_cmd} \

examples/3.x_api/pytorch/cv/static_quant/run_quant.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ function main {
1010

1111
# init params
1212
function init_params {
13+
batch_size=16
1314
tuned_checkpoint="saved_results"
1415
for var in "$@"
1516
do
@@ -22,6 +23,9 @@ function init_params {
2223
;;
2324
--input_model=*)
2425
input_model=$(echo $var |cut -f2 -d=)
26+
;;
27+
--batch_size=*)
28+
batch_size=$(echo $var |cut -f2 -d=)
2529
;;
2630
--output_model=*)
2731
tuned_checkpoint=$(echo $var |cut -f2 -d=)
@@ -44,7 +48,7 @@ function run_tuning {
4448
--pretrained \
4549
-t \
4650
-a resnet18 \
47-
-b 30 \
51+
-b ${batch_size} \
4852
--tuned_checkpoint ${tuned_checkpoint} \
4953
${dataset_location}
5054
}

0 commit comments

Comments
 (0)