Commit afb4645
AWQ+GPTQ (quic#101)
* Awq feature (quic#100)
  * added preprocess layer before loading quantized awq weights
  * added onnx export
  * added ScaledActivation class
  * refactoring the code to right places and added one single test for now
  * cleaned code
  * added proper tests, added decorator for updating quantizers, cleaned code
  * fixed CLI
  * added auto file for decorator
* bugfix for tests
* fixed tests for AWQ model
* Adding support for GPTQ models (quic#103)
  * Adding support for gptq models
  * Code cleaning and formating
  * ruff format and fixed some bug
  * Added tests for gptq
  * Bug-fix-1
  * fixed bugs-2
  * fixed bug-3
  * Added docstring
  * Addressed comments
  * Addressed comments
  * fixed bugs-3
  * ruff check and format
  * Addressed comments-3
* added liscence at top for missing file
* added export_and_compile and fixed bugs
* removed GPTQ test
* removed threading from pytest

Signed-off-by: Onkar Chougule <[email protected]>
Signed-off-by: Amit Raj <[email protected]>
Co-authored-by: Amit Raj <[email protected]>
1 parent 0ef6829 commit afb4645

File tree

19 files changed: +1384 / -60 lines

QEfficient/base/common.py

Lines changed: 1 addition & 10 deletions

@@ -31,7 +31,6 @@ class QEFF_MODEL_TYPE(Enum):

     CAUSALLM = "LLM"
     DIFFUSION = "DIFFUSION"
-    AWQ = "AWQ"


 MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP: Dict[QEFF_MODEL_TYPE, Type[QEFFBaseModel]] = {
@@ -56,15 +55,7 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE:
         )

     if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
-        # FIXME: Add logic to handle if quantization config is stored in separate quant_config.json outside of config, also create a separate function for this and below lines
-        quant_config = getattr(config, "quantization_config", getattr(config, "quant_config", None))
-        if quant_config is not None:
-            if quant_config.get("quant_method", None) == "awq":
-                return QEFF_MODEL_TYPE.AWQ
-            else:
-                raise NotImplementedError(f"current model type is not yet supported {type(config)}")
-        else:
-            return QEFF_MODEL_TYPE.CAUSALLM
+        return QEFF_MODEL_TYPE.CAUSALLM
     else:
         raise NotImplementedError(f"model type {type(config)} is not yet supported")
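With the AWQ branch gone, quantized causal-LM checkpoints are routed through the same CAUSALLM path as plain models. A hypothetical lookup sketch, not part of this commit (the function and map names come from the file above; the model path is a placeholder):

# Illustration only; get_hf_model_type and MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP are defined in QEfficient/base/common.py.
from QEfficient.base.common import MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP, get_hf_model_type

model_type = get_hf_model_type("/path/to/hf/model")  # QEFF_MODEL_TYPE.CAUSALLM, even for AWQ/GPTQ checkpoints
qeff_auto_cls = MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP[model_type]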

QEfficient/base/pytorch_transforms.py

Lines changed: 32 additions & 0 deletions

@@ -55,3 +55,35 @@ def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]):
             FlashAttention.register(LLamaAttention, LlamaFlashAttention)
         """
         cls._module_mapping[from_module] = to_module
+
+
+class ModuleMutatorTransform(PytorchTransform):
+    """Serves as base class for any transform that mutates pytorch module in any way.
+    Mutate here mean, we initialize a new pytorch module object using info from original module and
+    replace original module with new module.
+
+    Raises:
+        NotImplementedError: Not supposed to use directly, Create a subclass and implement mutate method and assign a valid nn.Module class to _match_class variable.
+    """
+
+    _match_class: nn.Module
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False
+        for name, module in model.named_children():
+            if isinstance(module, cls._match_class):
+                setattr(model, name, cls.mutate(module, model))
+                transformed = True
+            else:
+                cls.apply(module)
+
+        if isinstance(model, cls._match_class):
+            model = cls.mutate(model, None)
+            transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
+        raise NotImplementedError("Please implement your own method by inheriting this class")

QEfficient/customop/matmulnbits.py

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@ (new file; all 182 lines added)

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import math

import torch
from torch import nn


class QuantLinearTorchFunction(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
        input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
        input_tuple += (g_idx,) if g_idx is not None else ()
        return g.op(
            "com.microsoft::MatMulNBits",
            *input_tuple,
            outputs=1,
            K_i=in_features,
            N_i=out_features,
            bits_i=bits,
            block_size_i=group_size,
        )

    @staticmethod
    def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
        if torch.onnx.is_in_onnx_export():
            return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float()
        fp_weight = dequantize_blockwise_bits(
            qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features
        )[0].float()

        return torch.matmul(x.float(), fp_weight.T.float())


def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols):
    if bits != 4:
        raise ValueError("Only bits=4 is supported for executing quantized model")
    if group_size != 128:
        raise ValueError("Only group_size=128 is supported for executing quantized model")
    expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
    expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
    aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
    if zero_point.dtype == scale.dtype:
        expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1)
    else:
        expand_zero_point = (zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
        try:
            expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1)
        # FIXME: remove try-except
        except RuntimeError:
            expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1)
            expand_zero_point = expand_zero_point[:, : quant_values.shape[1]]
    if g_idx is not None and g_idx[:32].sum().item() != 0:
        float_values = (
            (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0])
            * aligned_scale[:, g_idx, 0]
        ).to(scale.dtype)
    else:
        float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype)
    float_values = float_values.reshape(cols, -1)
    if rows != float_values.shape[-1]:
        float_values = float_values[:, :rows]
        expand_zero_point = expand_zero_point[:, :rows]
    if expand_zero_point.ndim == 3:
        expand_zero_point = expand_zero_point.squeeze(-1)
    if aligned_scale.ndim == 3:
        aligned_scale = aligned_scale.squeeze(-1)

    return float_values, expand_zero_point, aligned_scale


class QuantLinearORT(nn.Module):
    def __init__(self, bits, group_size, in_features, out_features, bias):
        super().__init__()
        if bits not in [2, 3, 4, 5, 6, 7, 8]:
            raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.")
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.group_size = group_size if group_size != -1 else in_features
        self.act_order = None

        q_rows = in_features // self.group_size
        self.register_buffer(
            "qweight",
            torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
        )
        self.register_buffer(
            "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
        )
        self.register_buffer(
            "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32)
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
        else:
            self.bias = None

    def quant_weight(self, weight, scales, zeros, g_idx):
        scale_zeros = zeros * scales
        scale_mat = scales[g_idx]
        scale_zeros_mat = scale_zeros[g_idx]
        int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int)
        return int_weight_T

    def pack_on_device(self, int_weight, int_zeros):
        if self.bits != 4:
            raise ValueError("only 4bit is supported by ONNXRUNTIME for now.")

        # Order of groups
        self.act_order = self.g_idx[: self.group_size // self.bits].sum().item() != 0

        intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte()
        scales_pt = self.scales.T.to(int_weight.device)
        intweight_pt = int_weight.byte()

        block_size = self.group_size
        rows, cols = intweight_pt.shape
        blob_size = block_size // 2
        k_blocks = (rows + block_size - 1) // block_size
        padded_rows = k_blocks * block_size
        pad_len = padded_rows - rows
        if pad_len > 0:
            intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0)
        intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0)

        # Pack zeros if they are not float
        if int_zeros.dtype != self.scales.dtype:
            intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4)
            intzeros_pt = intzeros_pt.reshape(-1)

        # Pack weights
        intweight_pt_T = int_weight.T
        intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4)
        intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size)

        scales_pt = scales_pt.reshape(-1)

        # Validation checks
        if (self.qweight.shape != intweight_pt_T.shape) and (
            self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype
        ):
            raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module")

        # Assign buffers
        self.scales = scales_pt.float()
        self.qweight = intweight_pt_T.byte()  # Convert to uint8
        if int_zeros.dtype != self.scales.dtype:
            self.qzeros = intzeros_pt.byte()  # Convert to uint8
        else:
            self.qzeros = intzeros_pt

    def pack(self, linear, scales, zeros, g_idx=None):
        layer_weight = linear.weight.data
        self.scales = scales.T
        self.g_idx = g_idx.clone()
        int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx)
        return self.pack_on_device(int_weight, zeros.T)

    def forward(self, inputs):
        out = QuantLinearTorchFunction().apply(
            inputs,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx if self.act_order else None,
            self.bits,
            self.group_size,
            self.in_features,
            self.out_features,
        )
        out = out + self.bias if self.bias is not None else out
        return out
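As a usage sketch (not part of this commit), the module above can be exercised eagerly under the constraints it enforces (bits=4, group_size=128). Its buffers are zero-initialized, so the output stays zero until pack()/pack_on_device() loads real quantized weights; during ONNX export the same forward maps to the com.microsoft::MatMulNBits op.

# Hypothetical smoke test; the module path QEfficient.customop.matmulnbits comes from the file added above.
import torch

from QEfficient.customop.matmulnbits import QuantLinearORT

layer = QuantLinearORT(bits=4, group_size=128, in_features=256, out_features=64, bias=False)
x = torch.randn(1, 8, 256)
y = layer(x)  # eager path: dequantize_blockwise_bits + matmul; all zeros until weights are packed
print(y.shape)  # torch.Size([1, 8, 64])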

QEfficient/exporter/export_hf_to_cloud_ai_100.py

Lines changed: 1 addition & 1 deletion

@@ -443,7 +443,7 @@ def qualcomm_efficient_converter(
     model_kv = model_kv if model_kv.is_transformed else QEfficient.transform(model_kv) if kv else model_kv

     if onnx_dir_path is None:
-        model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_name))
+        model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_kv.model_card_name))
         onnx_dir_path = os.path.join(model_card_dir, "onnx")
         os.makedirs(onnx_dir_path, exist_ok=True)
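For illustration only (the directory value and model card name below are assumptions, not taken from the repo), the one-line change above only affects where the ONNX artifacts land: they are now grouped under the loaded model's model_card_name instead of the raw model_name argument.

# Hypothetical values; QEFF_MODELS_DIR and the model card name are placeholders for illustration.
import os

QEFF_MODELS_DIR = os.path.expanduser("~/.cache/qeff_models")  # assumed cache root
model_card_name = "example-org/example-7b"                    # stands in for model_kv.model_card_name

onnx_dir_path = os.path.join(QEFF_MODELS_DIR, str(model_card_name), "onnx")
os.makedirs(onnx_dir_path, exist_ok=True)  # e.g. ~/.cache/qeff_models/example-org/example-7b/onnx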
