-
Notifications
You must be signed in to change notification settings - Fork 349
Add Int8Tensor for clearer interface #3038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
namgyu-youn
wants to merge
21
commits into
pytorch:main
Choose a base branch
from
namgyu-youn:int8-quant
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
08e9095
Summary:
namgyu-youn db23cf3
rename for clearly: Int8PlainInt8Tensor -> Int8Tensor
namgyu-youn b861dbc
add flags for static/dynamic quant
namgyu-youn 9383550
update static/dynamic quantization workflows
namgyu-youn 2c84ba4
add kernel preference unit test
namgyu-youn 8ddddd3
add kernel preference unit test
namgyu-youn bd6f58a
Merge remote-tracking branch 'upstream/main' into int8-quant
namgyu-youn b5cb3c8
fix missing attribute
namgyu-youn 9a51cae
remove kernel preference args
namgyu-youn c53dad0
link new API with old API using version 2
namgyu-youn d300b02
add granularity, block size support
namgyu-youn c43a3ec
Merge branch 'main' into int8-quant
namgyu-youn 590e0b7
add transpose, index selector workflows
namgyu-youn b3d4f3e
remove external zero point
namgyu-youn df79aa8
update int8 quantization API
namgyu-youn 910906b
Merge remote-tracking branch 'upstream/main' into int8-quant
namgyu-youn c61b36e
add static quantization support
namgyu-youn 0a45f90
sync with main branch
namgyu-youn 1251187
split dispatch decorator
namgyu-youn 844d99d
update int8-quant api
namgyu-youn a844678
update type-hint to prevent depenedency issue
namgyu-youn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
222 changes: 222 additions & 0 deletions
222
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import copy | ||
import unittest | ||
|
||
import torch | ||
from torch.testing._internal import common_utils | ||
|
||
from torchao.quantization import ( | ||
Int8DynamicActivationInt8WeightConfig, | ||
Int8WeightOnlyConfig, | ||
PerRow, | ||
PerTensor, | ||
quantize_, | ||
) | ||
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine | ||
from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
Int8Tensor, | ||
QuantizeTensorToInt8Kwargs, | ||
) | ||
from torchao.quantization.utils import compute_error | ||
from torchao.testing.utils import TorchAOIntegrationTestCase | ||
|
||
|
||
# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged | ||
class ToyTwoLinearModel(torch.nn.Module): | ||
def __init__( | ||
self, | ||
input_dim, | ||
hidden_dim, | ||
output_dim, | ||
has_bias=False, | ||
dtype=None, | ||
device=None, | ||
): | ||
super().__init__() | ||
self.dtype = dtype | ||
self.device = device | ||
self.linear1 = torch.nn.Linear( | ||
input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device | ||
) | ||
self.linear2 = torch.nn.Linear( | ||
hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device | ||
) | ||
|
||
def forward(self, x): | ||
x = self.linear1(x) | ||
x = self.linear2(x) | ||
return x | ||
|
||
|
||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
@common_utils.instantiate_parametrized_tests | ||
class TestInt8Tensor(TorchAOIntegrationTestCase): | ||
def setUp(self): | ||
super().setUp() | ||
torch.manual_seed(42) | ||
self.weight_fp = torch.randn(4, 3, dtype=torch.bfloat16) | ||
self.input_fp = torch.randn(4, 3, dtype=torch.bfloat16) | ||
self.bias = torch.randn(4, dtype=torch.bfloat16) | ||
self.block_size = [4, 3] | ||
|
||
def test_creation_and_attributes(self): | ||
"""Test tensor creation, dtypes, and ranges""" | ||
tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) | ||
|
||
self.assertEqual(tensor.shape, (4, 3)) | ||
self.assertEqual(tensor.qdata.dtype, torch.int8) | ||
self.assertTrue( | ||
torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) | ||
) | ||
|
||
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
@common_utils.parametrize( | ||
"sizes", | ||
[ | ||
((128,), 256, 128), | ||
((32, 128), 64, 256), | ||
], | ||
) | ||
@common_utils.parametrize( | ||
"config", | ||
[ | ||
Int8DynamicActivationInt8WeightConfig(version=2), | ||
Int8WeightOnlyConfig(version=2), | ||
], | ||
) | ||
def test_int8_linear_variants( | ||
self, | ||
dtype: torch.dtype, | ||
sizes: tuple, | ||
config, | ||
): | ||
M, N, K = sizes | ||
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") | ||
|
||
# Create a linear layer | ||
m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") | ||
m_q = copy.deepcopy(m) | ||
|
||
# Quantize | ||
quantize_(m_q, config) | ||
|
||
output_original = m(input_tensor) | ||
output_quantized = m_q(input_tensor) | ||
|
||
error = compute_error(output_original, output_quantized) | ||
assert error > 20, f"Quantization error is too high got a SQNR of {error}" | ||
|
||
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
def test_static_quantization(self, dtype): | ||
"""Test static quantization with pre-computed scale""" | ||
K, N = 128, 64 | ||
weight = torch.randn(N, K, dtype=dtype, device="cuda") | ||
input_tensor = torch.randn(32, K, dtype=dtype, device="cuda") | ||
|
||
act_scale, _ = choose_qparams_affine( | ||
input=input_tensor, | ||
mapping_type=MappingType.SYMMETRIC, | ||
block_size=(1, K), | ||
target_dtype=torch.int8, | ||
quant_min=-128, | ||
quant_max=127, | ||
scale_dtype=dtype, | ||
zero_point_dtype=torch.int8, | ||
) | ||
|
||
# Create weight with static quantization | ||
weight_int8 = Int8Tensor.from_hp( | ||
weight, | ||
block_size=[N, K], | ||
act_quant_kwargs=QuantizeTensorToInt8Kwargs( | ||
block_size=[1, K], | ||
static_scale=act_scale, | ||
), | ||
) | ||
|
||
output = torch.nn.functional.linear(input_tensor, weight_int8) | ||
self.assertEqual(output.shape, (32, N)) | ||
self.assertEqual(output.dtype, dtype) | ||
|
||
@unittest.skip("granularity parameter not supported in current API") | ||
@common_utils.parametrize("granularity", [PerTensor(), PerRow()]) | ||
def test_slice_preserves_aliasing(self, granularity): | ||
config = Int8DynamicActivationInt8WeightConfig( | ||
granularity=granularity, version=2 | ||
) | ||
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) | ||
l.weight = torch.nn.Parameter( | ||
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") | ||
) | ||
quantize_(l, config) | ||
param = l.weight | ||
param_data = param.data | ||
param_data = param_data.narrow(0, 0, 512) | ||
# Making sure the aliasing is preserved in sliced quantized Tensor | ||
assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() | ||
assert param.data.scale.data_ptr() == param_data.scale.data_ptr() | ||
|
||
@common_utils.parametrize( | ||
"config", | ||
[ | ||
Int8DynamicActivationInt8WeightConfig(version=2), | ||
Int8WeightOnlyConfig(version=2), | ||
], | ||
) | ||
@common_utils.parametrize("device", ["cpu", "cuda"]) | ||
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
def test_slice(self, config, device, dtype): | ||
"""Test tensor slicing""" | ||
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) | ||
quantize_(dummy, config) | ||
|
||
weight1 = dummy.weight.clone().narrow(0, 0, 64) | ||
weight2 = dummy.weight.clone().narrow(1, 0, 128) | ||
|
||
self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64)) | ||
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128)) | ||
Comment on lines
+181
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add assert for scale as well? |
||
|
||
# Int8DynamicActivationInt8WeightConfig uses per-row (PerRow) | ||
# Int8WeightOnlyConfig uses per-tensor (PerTensor) | ||
if isinstance(config, Int8DynamicActivationInt8WeightConfig): | ||
# PerRow: dim 0 slicing affects scale, dim 1 doesn't | ||
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64)) | ||
self.assertEqual(weight2.scale, dummy.weight.scale) | ||
else: | ||
# PerTensor: scale unchanged by slicing | ||
self.assertEqual(weight1.scale, dummy.weight.scale) | ||
self.assertEqual(weight2.scale, dummy.weight.scale) | ||
|
||
def test_index_select(self): | ||
"""test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`.""" | ||
N, K = 256, 512 | ||
x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) | ||
x_int8 = Int8Tensor.from_hp(x, block_size=[N, K]) | ||
x_int8_0 = x_int8[0] | ||
torch.testing.assert_close( | ||
x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 | ||
) | ||
|
||
def test_error_handling_and_dequant(self): | ||
"""Test input validation and dequantization accuracy""" | ||
with self.assertRaises((AssertionError, ValueError, RuntimeError)): | ||
Int8Tensor.from_hp(torch.randn(5), [1]) | ||
|
||
with self.assertRaises((AssertionError, ValueError, RuntimeError)): | ||
Int8Tensor.from_hp(self.weight_fp, [1]) | ||
|
||
test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16) | ||
tensor = Int8Tensor.from_hp(test_data, [1, 2]) | ||
|
||
dequantized = torch.ops.aten.dequantize.self(tensor) | ||
self.assertEqual(dequantized.shape, test_data.shape) | ||
self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1) | ||
|
||
|
||
if __name__ == "__main__": | ||
common_utils.run_tests() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.