-
Couldn't load subscription status.
- Fork 354
Drop old quantization flows #3115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
56d16dc
964b7e4
fbb2f2b
04ea5f8
c09839c
39e9d9d
6ef8503
4651e63
79ee25b
df20cb5
c81e214
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,30 +16,6 @@ | |
| _replace_with_custom_fn_if_matches_filter, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.subclass import ( | ||
| Int4WeightOnlyQuantizedLinearWeight, | ||
| Int8WeightOnlyQuantizedLinearWeight, | ||
| ) | ||
|
|
||
|
|
||
| def _int8wo_api(mod, **kwargs): | ||
| quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False) | ||
|
|
||
|
|
||
| def _int8da_int8w_api(mod, **kwargs): | ||
| quantize_( | ||
| mod, | ||
| Int8DynamicActivationInt8WeightConfig(**kwargs), | ||
| set_inductor_config=False, | ||
| ) | ||
|
|
||
|
|
||
| def _int4wo_api(mod, **kwargs): | ||
| kwargs_copy = kwargs.copy() | ||
| if "groupsize" in kwargs_copy: | ||
| kwargs_copy["group_size"] = kwargs_copy["groupsize"] | ||
| del kwargs_copy["groupsize"] | ||
| quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False) | ||
|
|
||
|
|
||
| class ToyLinearModel(torch.nn.Module): | ||
|
|
@@ -117,38 +93,18 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): | |
| return _ref_change_linear_weights_to_woqtensors | ||
|
|
||
|
|
||
| _ref_change_linear_weights_to_int8_woqtensors = ( | ||
| _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) | ||
| ) | ||
| _ref_change_linear_weights_to_int4_woqtensors = ( | ||
| _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) | ||
| ) | ||
|
|
||
|
|
||
| torch._dynamo.config.cache_size_limit = 50000 | ||
|
|
||
|
|
||
| @torch.no_grad | ||
| def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): | ||
| if kwargs is None: | ||
| kwargs = {} | ||
|
|
||
| def _bench_quantized_tensor_subclass_perf(api, config, M, N, K): | ||
| m = ToyLinearModel( | ||
| M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda" | ||
| ).eval() | ||
| m_bf16 = copy.deepcopy(m) | ||
| m_ref = copy.deepcopy(m) | ||
| example_inputs = m.example_inputs() | ||
|
|
||
| api(m, **kwargs) | ||
|
|
||
| # reference | ||
| ref_api(m_ref, **kwargs) | ||
|
|
||
| res = m(*example_inputs) | ||
| ref = m_ref(*example_inputs) | ||
|
|
||
| assert torch.equal(res, ref) | ||
| api(m, config) # Pass both model and config | ||
|
|
||
| # perf comparison | ||
| from torchao.utils import benchmark_model | ||
|
|
@@ -158,22 +114,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): | |
| RUNS = 100 | ||
|
|
||
| torch._dynamo.reset() | ||
| m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True) | ||
| benchmark_model(m_ref, WARMUP, example_inputs) | ||
| ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs) | ||
| m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) | ||
| benchmark_model(m_bf16, WARMUP, example_inputs) | ||
| bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) | ||
|
|
||
| torch._dynamo.reset() | ||
| m = torch.compile(m, mode="max-autotune", fullgraph=True) | ||
| benchmark_model(m, WARMUP, example_inputs) | ||
| elapsed_time = benchmark_model(m, RUNS, example_inputs) | ||
|
|
||
| torch._dynamo.reset() | ||
| m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) | ||
| benchmark_model(m_bf16, WARMUP, example_inputs) | ||
| bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) | ||
|
|
||
| print( | ||
| f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" | ||
| f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -182,24 +133,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): | |
| (20, 2048, 2048), | ||
| ] | ||
|
|
||
| print("_int8da_int8w_api") | ||
|
|
||
| print("Int8DynamicActivationInt8WeightConfig") | ||
| for M, N, K in all_shapes: | ||
| _bench_quantized_tensor_subclass_perf( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Temporarily updated to use new APIs 2 times to fix CI, but maybe we can update |
||
| _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K | ||
| quantize_, | ||
| Int8DynamicActivationInt8WeightConfig(), | ||
| M, | ||
| N, | ||
| K, | ||
| ) | ||
|
|
||
| print("_int8wo_api") | ||
|
|
||
| print("Int8WeightOnlyConfig") | ||
| for M, N, K in all_shapes: | ||
| _bench_quantized_tensor_subclass_perf( | ||
| _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K | ||
| quantize_, | ||
| Int8WeightOnlyConfig(), | ||
| M, | ||
| N, | ||
| K, | ||
| ) | ||
|
|
||
| print("_int4wo_api") | ||
| kwargs = {"groupsize": 32, "version": 1} | ||
|
|
||
| print("Int4WeightOnlyConfig") | ||
| for M, N, K in all_shapes: | ||
| _bench_quantized_tensor_subclass_perf( | ||
| _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs | ||
| quantize_, | ||
| Int4WeightOnlyConfig(group_size=32), | ||
| M, | ||
| N, | ||
| K, | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about inlining `_int8wo_api`, `_int8da_int8w_api`, and `_int4wo_api`? They are each used only once across the codebase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think that's fine if they're only used in benchmarks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also cc @jainapurva, can you take a look at the benchmark changes?