Add Int8Tensor for clearer interface #3038
base: main
Conversation
Introduce new tensor subclass API for int8 quantization with a clearer interface. The main change can be summarized as follows:
- Old: complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: direct int8 tensor with qdata, scale, and zero_point attributes

Test plan: test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Future plan: implement block-wise quantization using the `block_size` parameter
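A rough usage sketch of the new interface (the import path and the exact from_hp signature are assumptions based on this description; the per-row block_size default matches what appears later in the diff):

```python
import torch

from torchao.quantization import Int8Tensor  # import path assumed

# High-precision weight, quantized with one scale per output row
w = torch.randn(256, 512, dtype=torch.bfloat16)
w_int8 = Int8Tensor.from_hp(w, block_size=[1, w.shape[-1]])

# The subclass exposes its quantization state directly
print(w_int8.qdata.dtype)   # torch.int8
print(w_int8.scale.shape)   # roughly one scale per row
w_hp = w_int8.dequantize()  # back to a high-precision tensor
```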
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
can you add a version 2 and expose this tensor through:
ao/torchao/quantization/quant_api.py, Line 1497 in 8525185
ao/torchao/quantization/quant_api.py, Line 1752 in 8525185
    result = result.to(scale.dtype) * scale
    result = result.view(*input_tensor.shape[:-1], -1)
else:
    # FP × INT8 (static)
also this is the code for weight only quant I think:
ao/torchao/dtypes/uintx/plain_layout.py, Line 250 in 122b307:
def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
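For reference, a weight-only path like _linear_fp_act_int8_weight_impl typically dequantizes the int8 weight and then runs a regular linear; a minimal illustrative sketch (not the actual plain_layout.py code):

```python
import torch
import torch.nn.functional as F

def _linear_fp_act_int8_weight_sketch(input_tensor, weight_tensor, bias):
    # weight_tensor is assumed to carry int8 qdata of shape [N, K] and a
    # per-row scale of shape [N, 1]; dequantize to the activation dtype,
    # then call a regular linear.
    w_fp = weight_tensor.qdata.to(input_tensor.dtype) * weight_tensor.scale
    return F.linear(input_tensor, w_fp, bias)
```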
Done at 9383550, thanks for pointing it out.
    raise ValueError("Expected 2D tensor and block_size length 2")

# Rounding function from high precision dtype
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
looks like block_size is not used? why is that?
you can check out:
ao/torchao/dtypes/uintx/plain_layout.py, Line 232 in 8c5c33e:
def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):
also this should be using these quant primitive ops:
ao/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py, Lines 79 to 97 in 8c5c33e:

scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)
ao/torchao/quantization/quant_api.py, Line 1566 in 8c5c33e:
new_weight = to_affine_quantized_intx(
ao/torchao/dtypes/affine_quantized_tensor.py, Line 325 in 8c5c33e:
scale, zero_point = choose_qparams_affine(
this might require a bit too much context, let me know if you would like us to take over
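Adapted to int8, from_hp could follow the same pattern built on those primitives; a hedged sketch (the symmetric int8 range and the exact import path are assumptions):

```python
import torch

# import path assumed; these primitives are the ones referenced in the snippet above
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

def _int8_from_hp_sketch(w: torch.Tensor, block_size: list):
    scale, zero_point = choose_qparams_affine(
        input=w,
        mapping_type=MappingType.SYMMETRIC,
        block_size=block_size,
        target_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
    )
    qdata = quantize_affine(
        input=w,
        block_size=block_size,
        scale=scale,
        zero_point=zero_point,
        output_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
    )
    return qdata, scale, zero_point
```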
Thanks, I'd surely like to take this over! I drafted this PR for those updates, but will look into it today (in about 6 hours).
BTW, version 2 is updated at c53dad0 (version 1 is the default).
please rebase, and let me know when this is ready for review again @namgyu-youn
)
else:
    assert config.version == 2, f"Unexpected version: {config.version}"
    block_size = [weight.shape[0], weight.shape[1]]
this should be the same as L1393 I think; you can extract L1390-L1393 out of the first if branch and reuse that
Isn't dividing the logic much safer and easier for deprecating the old API in the future? Other APIs like _float8_weight_only_quant_tensor have also been used with this convention, without a common branch.
it's fine to duplicate I think, but the current code for block_size doesn't support 3d though
Okay then I will keep this branch and update the assert for 3D check.
Oh, actually we are already doing the 3D check in from_hp(), by using:

if w.dim() != 2 or len(block_size) != 2:
    raise ValueError("Expected 2D tensor and block_size length 2")
torchao/quantization/quant_api.py
else:
    quantized_weight = Int8Tensor.from_hp(
        weight,
        block_size=get_weight_block_size(weight),
nit: can calculate block_size outside of the if/else
elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
    return Int8Tensor.from_hp(
        tensor,
        quant_kwargs.block_size or [1, tensor.shape[-1]],
nit: why not make block_size mandatory?
this one is still not resolved yet
        block_size (Optional[list[int]]): block size for quantization granularity
    """

    block_size: Optional[list[int]] = None
why is this optional?
It was a wrong type hint, because the API can't work without granularity; mandatory (non-optional) should be right.
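With block_size made mandatory, the kwargs container would look roughly like this (a sketch; the actual base class and full field set in the PR may differ, and static_scale is included only because it appears later in the discussion):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class QuantizeTensorToInt8Kwargs:
    # required: the API cannot pick a quantization granularity on its own
    block_size: List[int]
    # optional pre-computed activation scale for the static quantization path
    static_scale: Optional[torch.Tensor] = None
```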
# Reshape 1D scale to [N, 1] for broadcasting with [N, K] qdata
if scale.ndim == 1:
    scale = scale.unsqueeze(1)
is this needed?
)

@implements(aten.transpose.int)
we don't need this yet I think, we can remove for now and add later when needed
Could you tell me why there is no need to support transposition for quantized tensors? I thought it was just a type of tensor. If we remove this, how can users transpose it like Tensor.transpose()?
just haven't seen people using it yet; I think we should implement as little as possible to keep the maintenance burden low
if dim == 0 and tensor.scale.ndim >= 1:
    sliced_scale = aten.slice.Tensor(tensor.scale, 0, start, end, step)

sliced_shape = list(
why not get the shape from sliced tensor directly?
can you check:
@implements(aten.slice.Tensor)
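In isolation, the earlier suggestion is that the sliced shape can be read straight off the sliced qdata rather than recomputed by hand; a minimal standalone sketch:

```python
import torch

aten = torch.ops.aten

qdata = torch.randint(-128, 127, (256, 512), dtype=torch.int8)
dim, start, end, step = 0, 0, 64, 1

sliced_qdata = aten.slice.Tensor(qdata, dim, start, end, step)
sliced_shape = list(sliced_qdata.shape)  # [64, 512], no manual bookkeeping needed
```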
return args[0].dequantize()

@implements([torch.nn.functional.linear, aten.linear.default])
nit: implements is refactored now: https://github.com/pytorch/ao/pull/2866/files
if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
    aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
def _implements_torch_function(cls, torch_fns):
why these changes? is there some issue with rebase?
There was no merge conflict, so I overwrote this file, but I regret this commit; let me know if 0a45f90 should be reverted.
if not isinstance(activation_tensor, Int8Tensor):
    if weight_tensor.act_quant_kwargs.static_scale is not None:
        # INT8 × INT8 (static): symmetric quantization only
        static_scale = weight_tensor.act_quant_kwargs.static_scale
OK, if this is needed I think it should be included in _choose_quant_func_and_quantize_tensor as well?
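A rough sketch of how the static-scale branch might fold into _choose_quant_func_and_quantize_tensor (names and the Int8Tensor constructor arguments are assumptions based on the surrounding diff):

```python
import torch

# Int8Tensor and QuantizeTensorToInt8Kwargs are the classes introduced in this PR
def _choose_quant_func_and_quantize_tensor_sketch(tensor, quant_kwargs):
    if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
        if quant_kwargs.static_scale is not None:
            # static path: reuse the provided scale with symmetric int8 rounding
            scale = quant_kwargs.static_scale
            qdata = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
            return Int8Tensor(qdata, scale, quant_kwargs.block_size)
        # dynamic path: compute qparams from the tensor itself
        return Int8Tensor.from_hp(tensor, quant_kwargs.block_size)
    raise NotImplementedError(f"Unsupported quant_kwargs: {type(quant_kwargs)}")
```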
implements_torch_function = Int8Tensor.implements_torch_function

@implements([aten.dequantize.self])
is this needed? if not we should remove for now
if scale.numel() > 1 and scale.shape != qdata_fp.shape:
    scale = scale.view(*scale.shape, *[1] * (qdata_fp.ndim - scale.ndim))
is this needed?
It is needed for block-level granularity. For example,
- Row-wise: If scale shape is (64, 1) and w_q (quantized weight shape) is (256, 512), we can naturally broadcast them
- Channel-wise: If scale shape is (512,) and w_q is (256, 512), we can naturally broadcast them
- Block-size granularity: If scale shape is (32, 64) and w_q is (256, 512), we have to rescale to broadcast them.
But we can also reuse _maybe_expand_scale_to_tensor_shape, similar to:
ao/torchao/quantization/quantize_/workflows/float8/float8_tensor.py, Lines 149 to 154 in 4b79f9e:

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if output_dtype is None:
        output_dtype = self.dtype
    qdata, scale = self.qdata, self.scale
    return _dequantize_affine_float8(qdata, scale, output_dtype)
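For the block-granularity case, the rescaling amounts to expanding each per-block scale over its block before multiplying; a standalone sketch of that idea (not the _maybe_expand_scale_to_tensor_shape helper itself):

```python
import torch

def expand_block_scale(scale: torch.Tensor, block_size: list) -> torch.Tensor:
    # Repeat each per-block scale over its block along every dimension
    expanded = scale
    for dim, blk in enumerate(block_size):
        expanded = expanded.repeat_interleave(blk, dim=dim)
    return expanded

qdata = torch.randint(-128, 127, (256, 512), dtype=torch.int8)
scale = torch.rand(32, 64)  # one scale per 8x8 block of the 256x512 weight
w_hp = qdata.float() * expand_block_scale(scale, [8, 8])
print(w_hp.shape)  # torch.Size([256, 512])
```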
cls: type,
qdata: torch.Tensor,
scale: torch.Tensor,
block_size: list[int],
nit: I remember list has a higher Python version requirement, so it's probably better to change this to List from typing, I think
Thanks, is it only for List, not for Dict, Tuple, etc.?
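For context: the subscripted builtins need Python 3.9+ (PEP 585), and that applies to list, dict, tuple, set, etc. alike; the typing aliases work on older versions:

```python
from typing import Dict, List, Tuple

# Works on Python 3.8 and earlier 3.x releases
block_size: List[int] = [1, 512]
config: Dict[str, int] = {"bits": 8}
shape: Tuple[int, int] = (256, 512)

# Requires Python 3.9+ (PEP 585); the same applies to dict[...], tuple[...], set[...]
# block_size: list[int] = [1, 512]
```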
return module

def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
some rebase issue?
Updated log:
To reviewers:
Summary:
Introduce new tensor subclass API for int8 quantization with clearer interface.
The main change can be summarized as follows:
- Old: complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment), #2752
Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
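A hedged usage sketch of exercising the new path once it is exposed through quant_api.py (the config name and the version argument are assumptions based on this PR discussion, not confirmed API):

```python
import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig  # config name assumed

model = torch.nn.Sequential(torch.nn.Linear(512, 256)).to(torch.bfloat16)

# version=2 is assumed to select the new Int8Tensor-based workflow from this PR
quantize_(model, Int8WeightOnlyConfig(version=2))

# the Linear weight should now be backed by the new Int8Tensor subclass
```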