4 changes: 3 additions & 1 deletion docs/source/quantization_overview.rst
@@ -5,7 +5,7 @@ First we want to lay out the torchao stack::

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
---------------------------------------------------------------------------------------------
Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
@@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by derived dtype and packing format
- scaled int4
- preshuffled (special format to optimize for loading)
- float8 act + int4 weight dynamic quantization and int4 weight only quantization
* - Int8Tensor
- plain

.. note::
We don't have granularity-specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor; all granularities are implemented in the same Tensor. We typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
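
For illustration, a minimal sketch of how `block_size` encodes granularity (assuming the `Int8Tensor.from_hp(tensor, block_size)` API added in this PR; the per-row mapping follows the usual torchao convention where a block size of 1 along dim 0 gives one scale per row, and Int8Tensor may support only a subset of these options):

import torch
from torchao.quantization import Int8Tensor

w = torch.randn(128, 256, dtype=torch.bfloat16)

# One block covering the whole weight -> a single scale (per-tensor granularity)
w_per_tensor = Int8Tensor.from_hp(w, [128, 256])

# One block per output row -> one scale per row (per-row granularity)
w_per_row = Int8Tensor.from_hp(w, [1, 256])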
218 changes: 218 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest
from typing import Tuple

import torch
from torch.testing._internal import common_utils

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
PerRow,
PerTensor,
quantize_,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
Int8Tensor,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase


# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged
class ToyTwoLinearModel(torch.nn.Module):
def __init__(
self,
input_dim,
hidden_dim,
output_dim,
has_bias=False,
dtype=None,
device=None,
):
super().__init__()
self.dtype = dtype
self.device = device
self.linear1 = torch.nn.Linear(
input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
)
self.linear2 = torch.nn.Linear(
hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
def setUp(self):
super().setUp()
torch.manual_seed(42)
self.weight_fp = torch.randn(4, 3, dtype=torch.bfloat16)
self.input_fp = torch.randn(4, 3, dtype=torch.bfloat16)
self.bias = torch.randn(4, dtype=torch.bfloat16)
self.block_size = [4, 3]

def test_creation_and_attributes(self):
"""Test tensor creation, dtypes, and ranges"""
tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)

self.assertEqual(tensor.shape, (4, 3))
self.assertEqual(tensor.qdata.dtype, torch.int8)
self.assertTrue(
torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
],
)
@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
def test_int8_linear_variants(
self,
dtype: torch.dtype,
sizes: Tuple,
config,
):
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

# Create a linear layer
m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")
m_q = copy.deepcopy(m)

# Quantize
quantize_(m_q, config)

output_original = m(input_tensor)
output_quantized = m_q(input_tensor)

error = compute_error(output_original, output_quantized)
assert error > 20, (
f"Quantization error is too high, got an SQNR of {error}"
)

def test_linear_operations(self):
"""Test fp+int8 and int8+int8 linear ops"""
Contributor: this is not int8+int8 I think? this is weight only quant

weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)
input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size)

reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias)
result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias)
result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias)

self.assertEqual(result_fp.shape, reference.shape)
self.assertEqual(result_q8.shape, reference.shape)
self.assertTrue(compute_error(result_fp, reference) > 10)
self.assertTrue(compute_error(result_q8, reference) > 10)

def test_dynamic_quantization(self):
Contributor: I think you can remove these 2 tests actually, since they are already tested in test_int8_linear_variants

"""Test dynamic activation quantization"""
weight_q8_dynamic = Int8Tensor.from_hp(
self.weight_fp,
self.block_size,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(),
)

reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias)
result_dynamic = torch.nn.functional.linear(
self.input_fp, weight_q8_dynamic, self.bias
)

self.assertEqual(result_dynamic.shape, reference.shape)
Contributor: nit: probably add a test for compute_error comparing floating point weight and int8+int8 weight as well

Contributor Author: Aren't we already comparing 1) bfloat16 vs. int8 quant and 2) float16 vs. int8 quant via the dtype parameterization?

Contributor: yeah this test should be removed

@unittest.skip("granularity parameter not supported in current API")
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_slice_preserves_aliasing(self, granularity):
config = Int8DynamicActivationInt8WeightConfig(
granularity=granularity, version=2
)
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
)
quantize_(l, config)
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
# Making sure the aliasing is preserved in sliced quantized Tensor
assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("device", ["cpu", "cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_slice(self, config, device, dtype):
"""Test tensor slicing"""
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
quantize_(dummy, config)

weight1 = dummy.weight.clone().narrow(0, 0, 64)
weight2 = dummy.weight.clone().narrow(1, 0, 128)

self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
Comment on lines +180 to +181
Contributor: nit: add assert for scale as well?

def test_transpose(self):
Contributor: is this used anywhere? for most of the tensors we actually don't support transpose so far, we tend to add this only when needed

"""Test transpose operation"""
weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)
transposed = weight_q8.transpose(0, 1)

self.assertEqual(transposed.shape, (3, 4))
self.assertEqual(transposed.block_size, [3, 4])

def test_select(self):
"""Test select operation"""
weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)
selected = weight_q8.select(0, 0)

self.assertEqual(selected.shape, (3,))
Contributor: test the data as well?

Contributor: you can follow this:

with self.assertRaises(AssertionError):
weight_q8.select(1, 0)

def test_error_handling_and_dequant(self):
"""Test input validation and dequantization accuracy"""
with self.assertRaises((AssertionError, ValueError, RuntimeError)):
Int8Tensor.from_hp(torch.randn(5), [1])

with self.assertRaises((AssertionError, ValueError, RuntimeError)):
Int8Tensor.from_hp(self.weight_fp, [1])

test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
tensor = Int8Tensor.from_hp(test_data, [1, 2])

dequantized = torch.ops.aten.dequantize.self(tensor)
self.assertEqual(dequantized.shape, test_data.shape)
self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1)


if __name__ == "__main__":
common_utils.run_tests()
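
These tests require CUDA (the test class is skipped otherwise) and can be run directly, e.g. python test/quantization/quantize_/workflows/int8/test_int8_tensor.py, or through pytest; common_utils.run_tests() picks up the parametrized cases.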
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -97,6 +97,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8Tensor,
IntxOpaqueTensor,
IntxUnpackedToInt8Tensor,
)
@@ -170,6 +171,7 @@
"IntxOpaqueTensor",
"IntxUnpackedToInt8Tensor",
"Int4TilePackedTo4dTensor",
"Int8Tensor",
"Float8Tensor",
"Int4OpaqueTensor",
# smooth quant - subject to change
83 changes: 54 additions & 29 deletions torchao/quantization/quant_api.py
@@ -78,6 +78,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8Tensor,
IntxChooseQParamsAlgorithm,
IntxOpaqueTensor,
IntxPackingFormat,
@@ -1362,10 +1363,12 @@ class Int8WeightOnlyConfig(AOBaseConfig):
Otherwise, applies per-group quantization with the specified group size.
set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
for better performance with this quantization scheme.
version - Version of the config to use. Version 1 uses AffineQuantizedTensor for quantization (deprecated), version 2 uses Int8Tensor.
"""

group_size: Optional[int] = None
set_inductor_config: bool = True
version: int = 1

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
@@ -1376,22 +1379,30 @@ def __post_init__(self):


def _int8_weight_only_quantize_tensor(weight, config):
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
group_size = config.group_size
if group_size is None:
group_size = weight.shape[-1]
block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype,
)
if config.version == 1:
warnings.warn(
"Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
)
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
group_size = config.group_size
if group_size is None:
group_size = weight.shape[-1]
block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype,
)
else:
assert config.version == 2, f"Unexpected version: {config.version}"
block_size = [weight.shape[0], weight.shape[1]]
Contributor: this should be the same as L1393 I think, you can extract L1390-L1393 out of the first if branch and use that I think

Contributor Author (@namgyu-youn, Oct 7, 2025): Isn't dividing the logic safer and easier for deprecating the old API in the future? Other APIs like _float8_weight_only_quant_tensor also follow this convention without a common branch.

Contributor: it's fine to duplicate I think, but the current code for block_size doesn't support 3d though

Contributor Author: Okay, then I will keep this branch and update the assert for the 3D check.

Contributor Author (@namgyu-youn, Oct 17, 2025): Oh, actually we already do the 3D check in from_hp(), via:

if w.dim() != 2 or len(block_size) != 2:
    raise ValueError("Expected 2D tensor and block_size length 2")

new_weight = Int8Tensor.from_hp(weight, block_size=block_size)
return new_weight

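
For reference, a minimal usage sketch of the version-2 path (hedged: it mirrors the new tests above; with version=2 the linear weight becomes an Int8Tensor whose qdata and scale attributes are the ones exercised in test_slice):

import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(256, 128, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int8WeightOnlyConfig(version=2))

# The weight is now an Int8Tensor: int8 values in `qdata`, quantization scales in `scale`
print(type(linear.weight).__name__, linear.weight.qdata.dtype)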

Expand Down Expand Up @@ -1519,12 +1530,14 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
in original precision during decode operations.
set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
for better performance with this quantization scheme.
version (int): the version of the config; version 1 uses AffineQuantizedTensor, which we plan to deprecate/split, version 2 uses Int8Tensor
"""

layout: Optional[Layout] = PlainLayout()
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
weight_only_decode: bool = False
set_inductor_config: bool = True
version: int = 1

def __post_init__(self):
torch._C._log_api_usage_once(
@@ -1572,19 +1585,31 @@ def get_weight_block_size(x):
else:
input_quant_func = _int8_asymm_per_token_quant

block_size = get_weight_block_size(weight)
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype,
_layout=layout,
zero_point_domain=weight_zero_point_domain,
)
new_weight = to_linear_activation_quantized(new_weight, input_quant_func)
return new_weight
if config.version == 1:
warnings.warn(
"Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
)
block_size = get_weight_block_size(weight)
quantized_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype,
_layout=layout,
zero_point_domain=weight_zero_point_domain,
)
quantized_weight = to_linear_activation_quantized(
quantized_weight, input_quant_func
)
else:
quantized_weight = Int8Tensor.from_hp(
weight,
block_size=get_weight_block_size(weight),
Contributor: nit: can calculate block_size outside of the if/else
)

return quantized_weight

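A similar sketch for the dynamic-activation config (a hedged example following test_int8_linear_variants; it assumes a CUDA device as in the test setup, and the exact granularity and kernels are whatever the version-2 Int8Tensor path selects):

import copy

import torch

from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.quantization.utils import compute_error

model = torch.nn.Sequential(torch.nn.Linear(256, 512), torch.nn.Linear(512, 256))
model = model.to(torch.bfloat16).cuda().eval()
model_ref = copy.deepcopy(model)

quantize_(model, Int8DynamicActivationInt8WeightConfig(version=2))

x = torch.randn(8, 256, dtype=torch.bfloat16, device="cuda")
sqnr = compute_error(model_ref(x), model(x))  # the tests assert this stays above 20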

@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)