diff --git a/docs/source/quantization_overview.rst b/docs/source/quantization_overview.rst index f5c82bfe5f..df0a924b11 100644 --- a/docs/source/quantization_overview.rst +++ b/docs/source/quantization_overview.rst @@ -5,7 +5,7 @@ First we want to lay out the torchao stack:: Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor + Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor --------------------------------------------------------------------------------------------- Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize --------------------------------------------------------------------------------------------- @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma - scaled int4 - preshuffled (special format to optimize for loading) - float8 act + int4 weight dynamic quantization and int4 weight only quantization + * - Int8Tensor + - plain .. note:: We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options. diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py new file mode 100644 index 0000000000..f1c4ea240c --- /dev/null +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import unittest + +import torch +from torch.testing._internal import common_utils + +from torchao.quantization import ( + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + PerRow, + PerTensor, + quantize_, +) +from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine +from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + Int8Tensor, + QuantizeTensorToInt8Kwargs, +) +from torchao.quantization.utils import compute_error +from torchao.testing.utils import TorchAOIntegrationTestCase + + +# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + has_bias=False, + dtype=None, + device=None, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@common_utils.instantiate_parametrized_tests +class TestInt8Tensor(TorchAOIntegrationTestCase): + def setUp(self): + super().setUp() + torch.manual_seed(42) + self.weight_fp = torch.randn(4, 3, dtype=torch.bfloat16) + self.input_fp = torch.randn(4, 3, dtype=torch.bfloat16) + self.bias = torch.randn(4, dtype=torch.bfloat16) + self.block_size = [4, 3] + + def test_creation_and_attributes(self): + """Test tensor creation, dtypes, and ranges""" + tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) + + self.assertEqual(tensor.shape, (4, 3)) + self.assertEqual(tensor.qdata.dtype, torch.int8) + self.assertTrue( + torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) + ) + + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ], + ) + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_int8_linear_variants( + self, + dtype: torch.dtype, + sizes: tuple, + config, + ): + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + # Create a linear layer + m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") + m_q = copy.deepcopy(m) + + # Quantize + quantize_(m_q, config) + + output_original = m(input_tensor) + output_quantized = m_q(input_tensor) + + error = compute_error(output_original, output_quantized) + assert error > 20, f"Quantization error is too high got a SQNR of {error}" + + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_static_quantization(self, dtype): + """Test static quantization with pre-computed scale""" + K, N = 128, 64 + weight = torch.randn(N, K, dtype=dtype, device="cuda") + input_tensor = torch.randn(32, K, dtype=dtype, device="cuda") + + act_scale, _ = choose_qparams_affine( + input=input_tensor, + mapping_type=MappingType.SYMMETRIC, + block_size=(input_tensor.shape[0], K), + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=dtype, + zero_point_dtype=torch.int8, + ) + + # Create weight with static quantization + weight_int8 = Int8Tensor.from_hp( + weight, + block_size=[N, K], + act_quant_kwargs=QuantizeTensorToInt8Kwargs( + block_size=[input_tensor.shape[0], K], + 
static_scale=act_scale, + ), + ) + + output = torch.nn.functional.linear(input_tensor, weight_int8) + self.assertEqual(output.shape, (32, N)) + self.assertEqual(output.dtype, dtype) + + @unittest.skip("granularity parameter not supported in current API") + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_preserves_aliasing(self, granularity): + config = Int8DynamicActivationInt8WeightConfig( + granularity=granularity, version=2 + ) + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + @common_utils.parametrize("device", ["cpu", "cuda"]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_slice(self, config, device, dtype): + """Test tensor slicing""" + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + quantize_(dummy, config) + + weight1 = dummy.weight.clone().narrow(0, 0, 64) + weight2 = dummy.weight.clone().narrow(1, 0, 128) + + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64)) + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128)) + + # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow) + # Int8WeightOnlyConfig uses per-tensor (PerTensor) + if isinstance(config, Int8DynamicActivationInt8WeightConfig): + # PerRow: dim 0 slicing affects scale, dim 1 doesn't + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64)) + self.assertEqual(weight2.scale, dummy.weight.scale) + else: + # PerTensor: scale unchanged by slicing + self.assertEqual(weight1.scale, dummy.weight.scale) + self.assertEqual(weight2.scale, dummy.weight.scale) + + def test_index_select(self): + """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`.""" + N, K = 256, 512 + x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) + x_int8 = Int8Tensor.from_hp(x, block_size=[N, K]) + x_int8_0 = x_int8[0] + torch.testing.assert_close( + x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 + ) + + def test_error_handling_and_dequant(self): + """Test input validation and dequantization accuracy""" + with self.assertRaises((AssertionError, ValueError, RuntimeError)): + Int8Tensor.from_hp(torch.randn(5), [1]) + + with self.assertRaises((AssertionError, ValueError, RuntimeError)): + Int8Tensor.from_hp(self.weight_fp, [1]) + + test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16) + tensor = Int8Tensor.from_hp(test_data, [1, 2]) + + dequantized = torch.ops.aten.dequantize.self(tensor) + self.assertEqual(dequantized.shape, test_data.shape) + self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..0ed2a20229 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -97,6 +97,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -170,6 
+171,7 @@ "IntxOpaqueTensor", "IntxUnpackedToInt8Tensor", "Int4TilePackedTo4dTensor", + "Int8Tensor", "Float8Tensor", "Int4OpaqueTensor", # smooth quant - subject to change diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3bda8f91ab..2cedc6e165 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -81,6 +81,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxChooseQParamsAlgorithm, IntxOpaqueTensor, IntxPackingFormat, @@ -96,7 +97,6 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -148,18 +148,7 @@ "autoquant", "_get_subclass_inserter", "quantize_", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", "intx_quantization_aware_training", - "float8_weight_only", - "uintx_weight_only", - "fpx_weight_only", - "gemlite_uintx_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", @@ -203,12 +192,6 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ - if isinstance(model, Float8Linear): - with torch.device("meta"): - new_module = nn.Linear(model.in_features, model.out_features) - new_module.weight = model.weight - new_module.bias = model.bias - model = new_module if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization @@ -249,12 +232,6 @@ def _replace_with_custom_fn_if_matches_filter_with_name( Returns: None """ - if isinstance(model, Float8Linear): - with torch.device("meta"): - new_module = nn.Linear(model.in_features, model.out_features) - new_module.weight = model.weight - new_module.bias = model.bias - model = new_module if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization @@ -519,7 +496,7 @@ def quantize_( # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile - from torchao.quantization.quant_api import int4_weight_only + from torchao.quantization.quant_api import Int4WeightOnlyConfig m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) @@ -641,12 +618,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) def _int8_dynamic_activation_int4_weight_transform( module: torch.nn.Module, @@ -1012,12 +983,6 @@ def __post_init__(self): ) -# for bc -int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) def _int4_dynamic_activation_int4_weight_transform( module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig @@ -1075,12 +1040,6 @@ def __post_init__(self): ) -# for BC -gemlite_uintx_weight_only = 
_ConfigDeprecationWrapper( - "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig -) - - @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) def _gemlite_uintx_weight_only_transform( module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig @@ -1158,11 +1117,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") -# for BC -# TODO maybe change other callsites -int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig) - - def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details @@ -1369,32 +1323,37 @@ class Int8WeightOnlyConfig(AOBaseConfig): group_size: Optional[int] = None set_inductor_config: bool = True + version: int = 2 def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") -# for BC -int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig) - - def _int8_weight_only_quantize_tensor(weight, config): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - group_size = config.group_size - if group_size is None: - group_size = weight.shape[-1] - block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + group_size = config.group_size + if group_size is None: + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + else: + assert config.version == 2, f"Unexpected version: {config.version}" + block_size = [weight.shape[0], weight.shape[1]] + new_weight = Int8Tensor.from_hp(weight, block_size=block_size) return new_weight @@ -1522,12 +1481,14 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): in original precision during decode operations. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. 
+ version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Int8Tensor """ layout: Optional[Layout] = PlainLayout() act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC weight_only_decode: bool = False set_inductor_config: bool = True + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once( @@ -1535,12 +1496,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig -) - - def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type @@ -1576,18 +1531,36 @@ def get_weight_block_size(x): input_quant_func = _int8_asymm_per_token_quant block_size = get_weight_block_size(weight) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - _layout=layout, - zero_point_domain=weight_zero_point_domain, - ) - new_weight = to_linear_activation_quantized(new_weight, input_quant_func) - return new_weight + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + quantized_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, + ) + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func + ) + else: + from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + QuantizeTensorToInt8Kwargs, + ) + + assert config.version == 2, f"Unexpected version: {config.version}" + quantized_weight = Int8Tensor.from_hp( + weight, + block_size, + act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size), + ) + + return quantized_weight @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) @@ -1646,12 +1619,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") -# for BC -float8_weight_only = _ConfigDeprecationWrapper( - "float8_weight_only", Float8WeightOnlyConfig -) - - def _float8_weight_only_quant_tensor(weight, config): if config.version == 1: warnings.warn( @@ -1687,6 +1654,10 @@ def _float8_weight_only_transform( "applying int8 weight only quant requires module to have weight attribute" + " but {module} does not have one" ) + + if isinstance(module, Float8Linear): + module = _unwrap_float8_linear(module) + new_weight = _float8_weight_only_quant_tensor(module.weight, config) module.weight = torch.nn.Parameter(new_weight, requires_grad=False) @@ -1806,12 +1777,6 @@ def __post_init__(self): self.granularity = [activation_granularity, weight_granularity] -# for bc -float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper( - "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig -) - - def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -1896,6 +1861,9 @@ def _float8_dynamic_activation_float8_weight_transform( "applying float8 dynamic activation quant requires module to have weight attribute" + f"but 
{module} does not have one" ) + if isinstance(module, Float8Linear): + module = _unwrap_float8_linear(module) + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( module.weight, config ) @@ -1931,6 +1899,9 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform( ): assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0" + if isinstance(module, Float8Linear): + module = _unwrap_float8_linear(module) + weight = module.weight weight_dtype = config.weight_dtype activation_dtype = config.activation_dtype @@ -1981,12 +1952,6 @@ def __post_init__(self): ) -# for bc -float8_static_activation_float8_weight = _ConfigDeprecationWrapper( - "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig -) - - @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) def _float8_static_activation_float8_weight_transform( module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig @@ -1995,6 +1960,9 @@ def _float8_static_activation_float8_weight_transform( "Float8 static activation quantization is only supported on CUDA 8.9 and above" ) + if isinstance(module, Float8Linear): + module = _unwrap_float8_linear(module) + scale = config.scale activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -2066,12 +2034,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig") -# for BC -uintx_weight_only = _ConfigDeprecationWrapper( - "uintx_weight_only", UIntXWeightOnlyConfig -) - - @register_quantize_module_handler(UIntXWeightOnlyConfig) def _uintx_weight_only_transform( module: torch.nn.Module, config: UIntXWeightOnlyConfig @@ -2350,10 +2312,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig") -# for BC -fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig) - - @register_quantize_module_handler(FPXWeightOnlyConfig) def _fpx_weight_only_transform( module: torch.nn.Module, config: FPXWeightOnlyConfig @@ -2364,6 +2322,9 @@ def _fpx_weight_only_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() + if isinstance(module, Float8Linear): + module = _unwrap_float8_linear(module) + from torchao.dtypes import to_affine_quantized_fpx from torchao.dtypes.floatx import FloatxTensorCoreLayout @@ -2386,7 +2347,7 @@ def _fpx_weight_only_transform( @dataclass class ModuleFqnToConfig(AOBaseConfig): - """Per module configurations for torchao quantize_ API + r"""Per module configurations for torchao quantize_ API Args: `module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an @@ -2443,6 +2404,21 @@ def _module_fqn_to_config_handler( return module +def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear: + """ + Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias. + + Torchao inference quantization techniques are generally only applicable to nn.Linear + layers, so this helper is useful for unwrapping models trained with torchao float8 training, + which replaces nn.Linear layers with Float8Linear layers. 
+ """ + with torch.device("meta"): + new_module = nn.Linear(module.in_features, module.out_features) + new_module.weight = module.weight + new_module.bias = module.bias + return new_module + + torch.serialization.add_safe_globals( [ _int8_asymm_per_token_quant, diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 0adc8c786d..44dd09ff62 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( """ from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int8Tensor, QuantizeTensorToFloat8Kwargs, + QuantizeTensorToInt8Kwargs, ) if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): @@ -52,5 +54,11 @@ def _choose_quant_func_and_quantize_tensor( quant_kwargs.hp_value_ub, quant_kwargs.kernel_preference, ) + elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): + return Int8Tensor.from_hp( + tensor, + quant_kwargs.block_size, + act_quant_kwargs=quant_kwargs, + ) raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 4307637f8e..3891f28dbe 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -20,6 +20,10 @@ Int4Tensor, ) from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor +from .int8.int8_tensor import ( + Int8Tensor, + QuantizeTensorToInt8Kwargs, +) from .intx.intx_choose_qparams_algorithm import IntxChooseQParamsAlgorithm from .intx.intx_opaque_tensor import ( IntxOpaqueTensor, @@ -37,6 +41,8 @@ "Int4MarlinSparseTensor", "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", + "Int8Tensor", + "QuantizeTensorToInt8Kwargs", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py new file mode 100644 index 0000000000..6986874cb3 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -0,0 +1,284 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import List, Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.quantization.quant_primitives import ( + MappingType, + _maybe_expand_scale_to_tensor_shape, + choose_qparams_affine, + quantize_affine, +) +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"] + +aten = torch.ops.aten + + +@dataclass +class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): + """Tensor kwargs for creating int8 tensor (either activation or weight) + + Args: + block_size (List[int]): block size for quantization granularity + static_scale (Optional[torch.Tensor]): pre-computed scale for static quantization + """ + + block_size: List[int] + static_scale: Optional[torch.Tensor] = None + + +class Int8Tensor(TorchAOBaseTensor): + """ + int8 quantized tensor with plain layout + + Tensor Attributes: + qdata: (N, K) int8 quantized weight data + scale: scale factors for dequantization + + Non-Tensor Attributes: + block_size: block size for quantization granularity + act_quant_kwargs: flags for static/dynamic activation quantization + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["block_size"] + optional_tensor_attribute_names = [ + "act_quant_kwargs", + "dtype", + ] + + def __new__( + cls: type, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + act_quant_kwargs=None, + dtype=None, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, list(qdata.shape), **kwargs) + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + act_quant_kwargs=None, + dtype=None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.act_quant_kwargs = act_quant_kwargs + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " + f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" + ) + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + ): + if w.dim() != 2 or len(block_size) != 2: + raise ValueError("Expected 2D tensor and block_size length 2") + + if act_quant_kwargs and act_quant_kwargs.static_scale is not None: + # INT8 × INT8 (static) + scale = act_quant_kwargs.static_scale + zero_point = torch.zeros_like(scale, dtype=torch.int8) + else: + # INT8 × INT8 (dynamic): compute scale at runtime + scale, zero_point = choose_qparams_affine( + input=w, + mapping_type=MappingType.SYMMETRIC, + block_size=tuple(block_size), + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w.dtype, + zero_point_dtype=torch.int8, + ) + + int_data = quantize_affine( + w, + block_size=tuple(block_size), + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ) + + return cls( + int_data, + scale, + block_size, + act_quant_kwargs=act_quant_kwargs, + dtype=w.dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize int8 tensor to floating point""" + + qdata_fp = self.qdata.to(output_dtype) + # Reshape scale to broadcast if granularity is block-wise + scale_expanded = 
_maybe_expand_scale_to_tensor_shape( + self.scale, self.qdata.shape + ) + return qdata_fp * scale_expanded.to(output_dtype) + + +implements = Int8Tensor.implements +implements_torch_function = Int8Tensor.implements_torch_function + + +@implements([aten.dequantize.self]) +def _(func, types, args, kwargs): + """dequantization: int8 -> float""" + return args[0].dequantize() + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + """quantization: dynamic, static, weight-only int8 quantization""" + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert isinstance(weight_tensor, Int8Tensor), ( + f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" + ) + + if weight_tensor.act_quant_kwargs is not None: + if not isinstance(activation_tensor, Int8Tensor): + # Activation quantization + activation_tensor = _choose_quant_func_and_quantize_tensor( + activation_tensor, weight_tensor.act_quant_kwargs + ) + + x_vals = activation_tensor.qdata + x_scales = activation_tensor.scale + w_vals_t = weight_tensor.qdata.contiguous().t() + w_scales = weight_tensor.scale + + tmp = x_vals.reshape(-1, x_vals.shape[-1]) + x_scales_dtype = x_scales.dtype + + # Cast fp16 scale to float + intermediate_dtype = ( + torch.float if x_scales_dtype == torch.half else x_scales_dtype + ) + # Note: CUDA doesn't support int32/int64 matmul, so we convert to float + # Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int' + # This may introduce minor numerical differences compared to int arithmetic + y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype)) + y_dot_scaled = y_dot * x_scales.reshape(-1, 1).to(intermediate_dtype) + + result = (y_dot_scaled * w_scales).reshape( + *x_vals.shape[:-1], y_dot_scaled.shape[-1] + ) + result = result.to(activation_tensor.dtype) + else: + # FP × INT8 (weight-only) + result = func( + activation_tensor, weight_tensor.dequantize(activation_tensor.dtype), None + ) + + return result + bias if bias is not None else result + + +@implements([aten.slice.Tensor]) +def _(func, types, args, kwargs): + """Slice operation for Int8Tensor""" + tensor, dim, start, end, step = ( + args[0], + args[1], + args[2], + args[3], + args[4] if len(args) > 4 else 1, + ) + + assert dim in (0, 1), f"Only dim 0 or 1 supported, got {dim}" + + if end >= tensor.shape[dim]: + end = tensor.shape[dim] + + # Always slice the qdata + sliced_qdata = func(tensor.qdata, dim, start, end, step) + + if tensor.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = tensor.scale + elif dim < tensor.scale.ndim and tensor.scale.shape[dim] > 1: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = func(tensor.scale, dim, start, end, step) + else: + sliced_scale = tensor.scale + + # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i] + block_size = list(tensor.block_size) + + for i in range(len(block_size)): + block_size[i] = min(block_size[i], sliced_qdata.shape[i]) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + sliced_qdata, + sliced_scale, + block_size, + tensor.act_quant_kwargs, + tensor.dtype, + ), + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = args + assert dim == 0, f"Only dim=0 supported, got {dim}" + + selected_scale = self.scale if self.scale.ndim == 0 else 
self.scale[index] + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + self.qdata[index], + selected_scale, + self.block_size, + self.act_quant_kwargs, + self.dtype, + ), + ) + + +Int8Tensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Int8Tensor, QuantizeTensorToInt8Kwargs])
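Usage sketch (not part of the diff above): a minimal example of how the new version=2 int8 paths introduced in this PR are exercised, mirroring the added tests in test_int8_tensor.py. It assumes a CUDA device and the APIs shown above (quantize_, Int8WeightOnlyConfig(version=2), Int8DynamicActivationInt8WeightConfig(version=2), Int8Tensor.from_hp); the model shapes and dtypes are illustrative only.

import torch
from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8Tensor,
    Int8WeightOnlyConfig,
    quantize_,
)

# Weight-only int8 quantization via the new Int8Tensor subclass (version=2)
m = torch.nn.Sequential(torch.nn.Linear(128, 256, bias=False)).to(torch.bfloat16).to("cuda")
quantize_(m, Int8WeightOnlyConfig(version=2))
print(m[0].weight.qdata.dtype)  # torch.int8
print(m[0].weight.scale)        # single scale: version=2 weight-only uses the full weight shape as block_size

# Dynamic activation + weight int8 quantization: the weight carries act_quant_kwargs,
# so the activation is quantized on the fly inside the linear dispatch
m2 = torch.nn.Sequential(torch.nn.Linear(128, 256, bias=False)).to(torch.bfloat16).to("cuda")
quantize_(m2, Int8DynamicActivationInt8WeightConfig(version=2))
x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
y = m2(x)

# Int8Tensor can also be constructed and dequantized directly
w = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
w_int8 = Int8Tensor.from_hp(w, block_size=[64, 128])  # block_size == weight shape -> per-tensor scale
w_back = w_int8.dequantize()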