-
Notifications
You must be signed in to change notification settings - Fork 352
Add Int8Tensor for clearer interface #3038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
08e9095
db23cf3
b861dbc
9383550
2c84ba4
8ddddd3
bd6f58a
b5cb3c8
9a51cae
c53dad0
d300b02
c43a3ec
590e0b7
b3d4f3e
df79aa8
910906b
c61b36e
0a45f90
1251187
844d99d
a844678
2c0389a
bafeb43
7006cae
49a7a89
062f3cc
680cec9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from torch.testing._internal.common_utils import run_tests | ||
|
|
||
| from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
| Int8Tensor, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
| from torchao.quantization.utils import compute_error | ||
| from torchao.testing.utils import TorchAOIntegrationTestCase | ||
|
|
||
|
|
||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
| class TestInt8Tensor(TorchAOIntegrationTestCase): | ||
| def setUp(self): | ||
| super().setUp() | ||
| torch.manual_seed(42) | ||
| self.weight_fp = torch.randn(4, 3, dtype=torch.float32) | ||
| self.input_fp = torch.randn(2, 3, dtype=torch.float32) | ||
| self.bias = torch.randn(4) | ||
| self.block_size = [4, 3] | ||
|
|
||
| def test_creation_and_attributes(self): | ||
| """Test tensor creation, dtypes, and ranges""" | ||
| tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) | ||
|
|
||
| self.assertEqual(tensor.shape, (4, 3)) | ||
| self.assertEqual(tensor.qdata.dtype, torch.int8) | ||
| self.assertTrue( | ||
| torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) | ||
| ) | ||
|
|
||
| def test_linear_operations(self): | ||
| """Test fp+int8 and int8+int8 linear ops with quantization error check""" | ||
| weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) | ||
| input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size) | ||
|
|
||
| reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) | ||
| result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias) | ||
| result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias) | ||
|
|
||
| self.assertEqual(result_fp.shape, reference.shape) | ||
| self.assertEqual(result_q8.shape, reference.shape) | ||
| self.assertTrue(compute_error(result_fp, reference) > 10) | ||
| self.assertTrue(compute_error(result_q8, reference) > 10) | ||
|
|
||
| def test_dynamic_quantization(self): | ||
| weight_q8_dynamic = Int8Tensor.from_hp( | ||
| self.weight_fp, | ||
| self.block_size, | ||
| act_quant_kwargs=QuantizeTensorToInt8Kwargs(), | ||
| ) | ||
|
|
||
| reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) | ||
| result_dynamic = torch.nn.functional.linear( | ||
| self.input_fp, weight_q8_dynamic, self.bias | ||
| ) | ||
|
|
||
| self.assertEqual(result_dynamic.shape, reference.shape) | ||
|
||
|
|
||
| def test_error_handling_and_dequant(self): | ||
| """Test input validation and dequantization accuracy""" | ||
| # Test 1D tensor validation | ||
| with self.assertRaises((AssertionError, ValueError, RuntimeError)): | ||
| Int8Tensor.from_hp(torch.randn(5), [1]) | ||
|
|
||
| # Test wrong block_size validation | ||
| with self.assertRaises((AssertionError, ValueError, RuntimeError)): | ||
| Int8Tensor.from_hp(self.weight_fp, [1]) | ||
|
|
||
| # Test dequantization with exact values | ||
| test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) | ||
| tensor = Int8Tensor.from_hp(test_data, [1, 1]) | ||
|
|
||
| dequantized = torch.ops.aten.dequantize.self(tensor) | ||
| self.assertEqual(dequantized.shape, test_data.shape) | ||
| self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( | |
| """ | ||
| from torchao.quantization.quantize_.workflows import ( | ||
| Float8Tensor, | ||
| Int8Tensor, | ||
| QuantizeTensorToFloat8Kwargs, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
|
|
||
| if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): | ||
|
|
@@ -52,5 +54,11 @@ def _choose_quant_func_and_quantize_tensor( | |
| quant_kwargs.hp_value_ub, | ||
| quant_kwargs.kernel_preference, | ||
| ) | ||
| elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): | ||
| return Int8Tensor.from_hp( | ||
| tensor, | ||
| quant_kwargs.block_size or [1, tensor.shape[-1]], | ||
|
||
| kernel_preference=quant_kwargs.kernel_preference, | ||
| ) | ||
|
|
||
| raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,193 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||||||||||||||||||||||||||||||||||||
| # All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||
| # This source code is licensed under the BSD 3-Clause license found in the | ||||||||||||||||||||||||||||||||||||||||||||||
| # LICENSE file in the root directory of this source tree. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from torchao.quantization.quantize_.common import ( | ||||||||||||||||||||||||||||||||||||||||||||||
| KernelPreference, | ||||||||||||||||||||||||||||||||||||||||||||||
| QuantizeTensorKwargs, | ||||||||||||||||||||||||||||||||||||||||||||||
| _choose_quant_func_and_quantize_tensor, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| from torchao.utils import TorchAOBaseTensor | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| __all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| aten = torch.ops.aten | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||||||||||||
| class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Tensor kwargs for creating int8 tensor (either activation or weight) | ||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. | ||||||||||||||||||||||||||||||||||||||||||||||
| block_size (Optional[list[int]]): block size for quantization granularity | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| kernel_preference: KernelPreference = KernelPreference.AUTO | ||||||||||||||||||||||||||||||||||||||||||||||
jerryzh168 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
| block_size: Optional[list[int]] = None | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: Implement block-wise quantization using block_size | ||||||||||||||||||||||||||||||||||||||||||||||
| class Int8Tensor(TorchAOBaseTensor): | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| int8 quantized tensor with plain layout | ||||||||||||||||||||||||||||||||||||||||||||||
| Tensor Attributes: | ||||||||||||||||||||||||||||||||||||||||||||||
| qdata: (N, K) int8 quantized weight data | ||||||||||||||||||||||||||||||||||||||||||||||
| scale: scale factors for dequantization | ||||||||||||||||||||||||||||||||||||||||||||||
| zero_point: zero points for dequantization | ||||||||||||||||||||||||||||||||||||||||||||||
| Non-Tensor Attributes: | ||||||||||||||||||||||||||||||||||||||||||||||
| block_size: block size for quantization granularity | ||||||||||||||||||||||||||||||||||||||||||||||
| shape: original tensor shape | ||||||||||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs: flags for static/dynamic activation quantization | ||||||||||||||||||||||||||||||||||||||||||||||
| kernel_preference: kernel preference for operations | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| tensor_data_names = ["qdata", "scale", "zero_point"] | ||||||||||||||||||||||||||||||||||||||||||||||
| tensor_attribute_names = ["block_size"] | ||||||||||||||||||||||||||||||||||||||||||||||
| optional_tensor_attribute_names = [ | ||||||||||||||||||||||||||||||||||||||||||||||
| "act_quant_kwargs", | ||||||||||||||||||||||||||||||||||||||||||||||
| "kernel_preference", | ||||||||||||||||||||||||||||||||||||||||||||||
| "dtype", | ||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __new__( | ||||||||||||||||||||||||||||||||||||||||||||||
jerryzh168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||||||||||||||||||||||||
| qdata, | ||||||||||||||||||||||||||||||||||||||||||||||
| scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| zero_point, | ||||||||||||||||||||||||||||||||||||||||||||||
| block_size, | ||||||||||||||||||||||||||||||||||||||||||||||
| shape, | ||||||||||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs=None, | ||||||||||||||||||||||||||||||||||||||||||||||
| kernel_preference=KernelPreference.AUTO, | ||||||||||||||||||||||||||||||||||||||||||||||
| dtype=None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = { | ||||||||||||||||||||||||||||||||||||||||||||||
| "device": qdata.device, | ||||||||||||||||||||||||||||||||||||||||||||||
| "dtype": dtype or scale.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||
| "requires_grad": False, | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||
jerryzh168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| qdata, | ||||||||||||||||||||||||||||||||||||||||||||||
| scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| zero_point, | ||||||||||||||||||||||||||||||||||||||||||||||
| block_size, | ||||||||||||||||||||||||||||||||||||||||||||||
| shape, | ||||||||||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs=None, | ||||||||||||||||||||||||||||||||||||||||||||||
| kernel_preference=KernelPreference.AUTO, | ||||||||||||||||||||||||||||||||||||||||||||||
| dtype=None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||
| self.qdata = qdata | ||||||||||||||||||||||||||||||||||||||||||||||
| self.scale = scale | ||||||||||||||||||||||||||||||||||||||||||||||
| self.zero_point = zero_point | ||||||||||||||||||||||||||||||||||||||||||||||
| self.block_size = block_size | ||||||||||||||||||||||||||||||||||||||||||||||
| self.act_quant_kwargs = act_quant_kwargs | ||||||||||||||||||||||||||||||||||||||||||||||
| self.kernel_preference = kernel_preference | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __repr__(self): | ||||||||||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"{self.zero_point=}, {self.block_size=}, {self.kernel_preference=}, " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"{self.shape=}, {self.device=}, {self.dtype=})" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||
| def from_hp( | ||||||||||||||||||||||||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||||||||||||||||||||||||
| w: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||
| block_size: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||
| act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| kernel_preference: KernelPreference = KernelPreference.AUTO, | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| if w.dim() != 2 or len(block_size) != 2: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Expected 2D tensor and block_size length 2") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Rounding function from high precision dtype | ||||||||||||||||||||||||||||||||||||||||||||||
| scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0 | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): |
also this should be using these quant primitive ops:
ao/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py
Lines 79 to 97 in 8c5c33e
| scale, zero_point = choose_qparams_affine( | |
| input=preprocessed_w, | |
| mapping_type=MappingType.SYMMETRIC, | |
| block_size=block_size, | |
| target_dtype=target_dtype, | |
| quant_min=quant_min, | |
| quant_max=quant_max, | |
| eps=1e-6, | |
| ) | |
| wq = quantize_affine( | |
| input=preprocessed_w, | |
| block_size=block_size, | |
| scale=scale, | |
| zero_point=zero_point, | |
| output_dtype=target_dtype, | |
| quant_min=quant_min, | |
| quant_max=quant_max, | |
| ) |
ao/torchao/quantization/quant_api.py
Line 1566 in 8c5c33e
| new_weight = to_affine_quantized_intx( |
ao/torchao/dtypes/affine_quantized_tensor.py
Line 325 in 8c5c33e
| scale, zero_point = choose_qparams_affine( |
this might require a bit too much context, let me know if you would like us to take over
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, surely want to take over! Drafted this PR for those updates, but will look into it today (6 hours later)
btw, version 2 is updated at c53dad0 (version 1 is default)
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this needed? if not we should remove for now
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: implements is refactored now: https://github.com/pytorch/ao/pull/2866/files
jerryzh168 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also this is the code for weight only quant I think:
ao/torchao/dtypes/uintx/plain_layout.py
Line 250 in 122b307
| def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done at 9383550 , thanks for pointing it out.
Uh oh!
There was an error while loading. Please reload this page.