
Conversation

namgyu-youn (Contributor) commented Sep 21, 2025

Summary:
Introduce a new tensor subclass API for int8 quantization with a clearer interface.

The main change can be summarized as follows:

  • Old: complex affine transform (AffineQuantizedTensor) with separate layout handling
  • New: direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment) #2752

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py


Future plan:
Implement block-wise quantization using the `block_size` parameter.
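A minimal usage sketch of the proposed interface (the import path and tensor shapes are assumptions; `from_hp`, `qdata`, `scale`, and `dequantize` are the names used in this PR):

import torch
from torchao.quantization import Int8Tensor  # import path is an assumption

# Build the subclass from a high-precision weight; per-row granularity is
# expressed as block_size = [1, K], i.e. one scale per output channel.
weight = torch.randn(256, 512, dtype=torch.bfloat16)
w_int8 = Int8Tensor.from_hp(weight, block_size=[1, weight.shape[-1]])

print(w_int8.qdata.dtype)   # torch.int8
print(w_int8.scale.shape)   # per-row scale
w_hp = w_int8.dequantize()  # back to bfloat16 for a quick error check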

pytorch-bot bot commented Sep 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 21, 2025
jerryzh168 (Contributor) commented

Can you add a version 2 and expose this tensor through `Int8DynamicActivationInt8WeightConfig(AOBaseConfig)`, similar to `Float8DynamicActivationFloat8WeightConfig(AOBaseConfig)`?
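A rough sketch of what opting in could look like (`quantize_` and the config class already exist in torchao; routing version 2 to the new tensor is what this thread asks for, and passing `version=2` to the constructor is an assumption modeled on the Float8 config):

import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(512, 256, dtype=torch.bfloat16))
# version=1 keeps the existing AffineQuantizedTensor path (the default);
# version=2 would route the weight through the new Int8Tensor subclass.
quantize_(model, Int8DynamicActivationInt8WeightConfig(version=2))
print(type(model[0].weight))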

@namgyu-youn namgyu-youn changed the title Add Int8PlainInt8Tensor for clearer interface Add Int8Tensor for clearer interface Sep 23, 2025
    result = result.to(scale.dtype) * scale
    result = result.view(*input_tensor.shape[:-1], -1)
else:
    # FP × INT8 (static)
jerryzh168 (Contributor) commented:

also this is the code for weight only quant I think:

def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
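For reference, a minimal sketch of a weight-only int8 linear written against this PR's qdata/scale attributes (illustrative only, not the existing _linear_fp_act_int8_weight_impl; a per-row symmetric scale is assumed):

import torch
import torch.nn.functional as F

def weight_only_int8_linear(input_tensor, weight_tensor, bias=None):
    # weight_tensor.qdata: [N, K] int8; weight_tensor.scale: per-row scale, [N] or [N, 1]
    w_fp = weight_tensor.qdata.to(input_tensor.dtype)
    out = F.linear(input_tensor, w_fp)                          # [..., N]
    out = out * weight_tensor.scale.reshape(-1).to(out.dtype)   # rescale per output channel
    if bias is not None:
        out = out + bias
    return out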

namgyu-youn (Contributor, Author) commented Sep 24, 2025:

Done at 9383550, thanks for pointing it out.

    raise ValueError("Expected 2D tensor and block_size length 2")

# Per-row symmetric scale from the high-precision weight
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
jerryzh168 (Contributor) commented:

looks like block_size is not used? why is that?

jerryzh168 (Contributor) commented:

you can check out

def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):
for expected granularity

also this should be using these quant primitive ops:

scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)
Arguments can be found by tracing through the code path for int8 in
new_weight = to_affine_quantized_intx(
and
scale, zero_point = choose_qparams_affine(

this might require a bit too much context, let me know if you would like us to take over
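A rough sketch of how from_hp could wrap those primitives (the primitive names and keyword arguments are as quoted above; the module path, the int8 quant_min/quant_max values, and the wrapper itself are illustrative):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

def int8_from_hp(w: torch.Tensor, block_size: list[int]):
    # symmetric int8; granularity is driven entirely by block_size
    scale, zero_point = choose_qparams_affine(
        input=w,
        mapping_type=MappingType.SYMMETRIC,
        block_size=block_size,
        target_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        eps=1e-6,
    )
    wq = quantize_affine(
        input=w,
        block_size=block_size,
        scale=scale,
        zero_point=zero_point,
        output_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
    )
    return wq, scale, zero_point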

namgyu-youn (Contributor, Author) commented Sep 29, 2025:

Thanks, I definitely want to take this over! I've marked this PR as a draft for those updates, but will look into it today (about 6 hours from now).

btw, version 2 is added at c53dad0 (version 1 is the default)

@namgyu-youn namgyu-youn marked this pull request as draft September 28, 2025 13:23
@namgyu-youn namgyu-youn marked this pull request as ready for review September 30, 2025 06:09
jerryzh168 (Contributor) left a comment:

please rebase, and let me know when this is ready for review again @namgyu-youn

)
else:
    assert config.version == 2, f"Unexpected version: {config.version}"
    block_size = [weight.shape[0], weight.shape[1]]
jerryzh168 (Contributor) commented:

this should be the same as L1393 I think, you can extract L1390-L1393 out of the first if branch and use that I think

namgyu-youn (Contributor, Author) commented Oct 7, 2025:

Isn't keeping the logic separate safer, and easier for deprecating the old API in the future? Other APIs like _float8_weight_only_quant_tensor have also followed this convention without a common branch.

jerryzh168 (Contributor) commented:

it's fine to duplicate I think, but the current code for block_size doesn't support 3d though

namgyu-youn (Contributor, Author) commented:

Okay then I will keep this branch and update the assert for 3D check.

namgyu-youn (Contributor, Author) commented Oct 17, 2025:

Oh, actually we already do the 3D check in from_hp(), using:

if w.dim() != 2 or len(block_size) != 2:
    raise ValueError("Expected 2D tensor and block_size length 2")

else:
    quantized_weight = Int8Tensor.from_hp(
        weight,
        block_size=get_weight_block_size(weight),
jerryzh168 (Contributor) commented:

nit: can calculate block_size outside of the if/else

elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
    return Int8Tensor.from_hp(
        tensor,
        quant_kwargs.block_size or [1, tensor.shape[-1]],
jerryzh168 (Contributor) commented:

nit: why not make block_size mandatory?

jerryzh168 (Contributor) commented:

this one is still not resolved yet

block_size (Optional[list[int]]): block size for quantization granularity
"""

block_size: Optional[list[int]] = None
jerryzh168 (Contributor) commented:

why is this optional?

namgyu-youn (Contributor, Author) commented:

It was a wrong type hint, because the API can't work without granularity; making it mandatory (not optional) is right.

Comment on lines 144 to 146
# Reshape 1D scale to [N, 1] for broadcasting with [N, K] qdata
if scale.ndim == 1:
    scale = scale.unsqueeze(1)
jerryzh168 (Contributor) commented:

is this needed?

)


@implements(aten.transpose.int)
jerryzh168 (Contributor) commented:

we don't need this yet I think, we can remove for now and add later when needed

namgyu-youn (Contributor, Author) commented Oct 7, 2025:

Could you tell me why there is no need to support transposition for quantized tensors? I thought it was just a type of tensor. If we remove this, how can users transpose it like Tensor.transpose()?

jerryzh168 (Contributor) commented Oct 14, 2025:

just haven't seen people using it yet, I think we should implement as little as possible to keep the maintenance burden low

if dim == 0 and tensor.scale.ndim >= 1:
    sliced_scale = aten.slice.Tensor(tensor.scale, 0, start, end, step)

    sliced_shape = list(
jerryzh168 (Contributor) commented:

why not get the shape from sliced tensor directly?

jerryzh168 (Contributor) commented:

can you check

? I'm not sure if the current implementation is enough to cover all cases actually

@namgyu-youn namgyu-youn requested a review from jerryzh168 October 7, 2025 10:25
return args[0].dequantize()


@implements([torch.nn.functional.linear, aten.linear.default])
jerryzh168 (Contributor) commented:

nit: implements is refactored now: https://github.com/pytorch/ao/pull/2866/files


if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
    aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
def _implements_torch_function(cls, torch_fns):
jerryzh168 (Contributor) commented Oct 16, 2025:

why these changes? is there some issue with rebase?

namgyu-youn (Contributor, Author) commented Oct 17, 2025:

There was no merge conflict, so I overwrote this file, but I regret this commit; let me know if 0a45f90 should be reverted.

if not isinstance(activation_tensor, Int8Tensor):
    if weight_tensor.act_quant_kwargs.static_scale is not None:
        # INT8 × INT8 (static): symmetric quantization only
        static_scale = weight_tensor.act_quant_kwargs.static_scale
jerryzh168 (Contributor) commented:

OK if this is needed I think it should be included in _choose_quant_func_and_quantize_tensor as well?
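One possible shape for that, as a sketch only (QuantizeTensorToInt8Kwargs, static_scale, from_hp, and the constructor fields are names from this PR; the import path and building the tensor directly from a static scale this way are assumptions):

import torch
from torchao.quantization import Int8Tensor, QuantizeTensorToInt8Kwargs  # import path is an assumption

def _choose_quant_func_and_quantize_tensor(tensor, quant_kwargs):
    if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
        if quant_kwargs.static_scale is not None:
            # static path: reuse the provided activation scale instead of computing one
            scale = quant_kwargs.static_scale
            qdata = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
            return Int8Tensor(qdata, scale, quant_kwargs.block_size)
        # dynamic path: derive the qparams from the tensor itself
        return Int8Tensor.from_hp(tensor, quant_kwargs.block_size)
    raise NotImplementedError(f"Unsupported quant_kwargs: {type(quant_kwargs)}")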

implements_torch_function = Int8Tensor.implements_torch_function


@implements([aten.dequantize.self])
jerryzh168 (Contributor) commented:

is this needed? if not we should remove for now

Comment on lines 142 to 143
if scale.numel() > 1 and scale.shape != qdata_fp.shape:
    scale = scale.view(*scale.shape, *[1] * (qdata_fp.ndim - scale.ndim))
jerryzh168 (Contributor) commented:

is this needed?

namgyu-youn (Contributor, Author) commented Oct 17, 2025:

It is needed for block-level granularity. For example:

  1. Row-wise: if the scale shape is (256, 1) and w_q (the quantized weight) is (256, 512), they broadcast naturally.
  2. Channel-wise: if the scale shape is (512,) and w_q is (256, 512), they broadcast naturally.
  3. Block-wise granularity: if the scale shape is (32, 64) and w_q is (256, 512), we have to expand the scale before broadcasting (see the sketch below).

But we can also reuse _maybe_expand_scale_to_tensor_shape, similar to:

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if output_dtype is None:
        output_dtype = self.dtype
    qdata, scale = self.qdata, self.scale
    return _dequantize_affine_float8(qdata, scale, output_dtype)
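A sketch of the block-wise expansion being described (the helper mirrors the role of _maybe_expand_scale_to_tensor_shape, but this version is illustrative):

import torch

def expand_scale_to_tensor_shape(scale: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    # scalar, matching, or lower-rank scales broadcast natively; block-wise
    # scales need every element repeated across its block along each dim
    if scale.numel() == 1 or scale.shape == target_shape or scale.ndim != len(target_shape):
        return scale
    expanded = scale
    for dim, (s, t) in enumerate(zip(scale.shape, target_shape)):
        if s != t:
            assert t % s == 0, "scale dims must evenly divide the tensor dims"
            expanded = expanded.repeat_interleave(t // s, dim=dim)
    return expanded

# e.g. block granularity: a (32, 64) scale expands to (256, 512) for a (256, 512) qdata
print(expand_scale_to_tensor_shape(torch.rand(32, 64), torch.Size([256, 512])).shape)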

cls: type,
qdata: torch.Tensor,
scale: torch.Tensor,
block_size: list[int],
jerryzh168 (Contributor) commented:

nit: I remember list has a higher Python version requirement, so it's probably better to change this to List from typing, I think.

namgyu-youn (Contributor, Author) commented:

Thanks, is it only for List, and not for Dict, Tuple, etc.?
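For reference, the subscripted builtins all share the same Python minimum (PEP 585, Python 3.9), so dict and tuple are in the same situation; a quick illustration, not code from this PR:

# Python >= 3.9 (PEP 585): the builtins are subscriptable directly
def f(block_size: list[int], extra: dict[str, int], shape: tuple[int, ...]) -> None: ...

# Python < 3.9: the typing aliases are needed for all of them, not just List
from typing import Dict, List, Tuple

def g(block_size: List[int], extra: Dict[str, int], shape: Tuple[int, ...]) -> None: ...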

return module


def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
jerryzh168 (Contributor) commented:

some rebase issue?

namgyu-youn (Contributor, Author) commented Oct 17, 2025

Updated log:

To reviewers:
Unfortunately, I can't build and run local tests because of #2919, even after trying a downgrade and gradual installation. Please feel free to commit directly if test_int8_tensor.py fails.

