# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

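"""4-bit block-wise quantized linear layer for ONNX Runtime.

QuantLinearORT stores packed weights, zero points and scales in the layout
consumed by the ``com.microsoft::MatMulNBits`` contrib operator, and falls back
to dequantize-then-matmul when executed eagerly in PyTorch.
"""
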
import math

import torch
from torch import nn


class QuantLinearTorchFunction(torch.autograd.Function):
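    """Autograd function that emits ``com.microsoft::MatMulNBits`` during ONNX
    export and dequantizes the packed weight for eager execution."""
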
    @staticmethod
    def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
        input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
        input_tuple += (g_idx,) if g_idx is not None else ()
        return g.op(
            "com.microsoft::MatMulNBits",
            *input_tuple,
            outputs=1,
            K_i=in_features,
            N_i=out_features,
            bits_i=bits,
            block_size_i=group_size,
        )

    @staticmethod
    def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
        if torch.onnx.is_in_onnx_export():
            return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float()
        fp_weight = dequantize_blockwise_bits(
            qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features
        )[0].float()

        return torch.matmul(x.float(), fp_weight.T.float())


def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols):
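    """Expand 4-bit packed ``quant_values`` into floats of ``scale.dtype``.

    Returns the dequantized weight of shape ``(cols, rows)`` along with the
    expanded zero points and per-group scales.
    """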
    if bits != 4:
        raise ValueError("Only bits=4 is supported for executing the quantized model")
    if group_size != 128:
        raise ValueError("Only group_size=128 is supported for executing the quantized model")
    # Unpack two 4-bit values from every byte (low nibble first)
    expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
    expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
    aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
    # Zero points are stored either as floats (same dtype as scales) or packed 4-bit integers
    if zero_point.dtype == scale.dtype:
        expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1)
    else:
        expand_zero_point = (zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
        try:
            expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1)
        # FIXME: remove try-except
        except RuntimeError:
            expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1)
            expand_zero_point = expand_zero_point[:, : quant_values.shape[1]]
    # With a non-trivial group order (act-order), look up scale/zero per input column via g_idx;
    # otherwise broadcast them per group.
    if g_idx is not None and g_idx[:32].sum().item() != 0:
        float_values = (
            (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0])
            * aligned_scale[:, g_idx, 0]
        ).to(scale.dtype)
    else:
        float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype)
    float_values = float_values.reshape(cols, -1)
    if rows != float_values.shape[-1]:
        float_values = float_values[:, :rows]
        expand_zero_point = expand_zero_point[:, :rows]
    if expand_zero_point.ndim == 3:
        expand_zero_point = expand_zero_point.squeeze(-1)
    if aligned_scale.ndim == 3:
        aligned_scale = aligned_scale.squeeze(-1)

    return float_values, expand_zero_point, aligned_scale


class QuantLinearORT(nn.Module):
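    """Linear layer whose weight is stored block-wise quantized in the packed
    buffers (``qweight``, ``qzeros``, ``scales``) expected by MatMulNBits."""
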
    def __init__(self, bits, group_size, in_features, out_features, bias):
        super().__init__()
        if bits not in [2, 3, 4, 5, 6, 7, 8]:
            raise NotImplementedError("Only 2, 3, 4, 5, 6, 7 and 8 bits are supported.")
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.group_size = group_size if group_size != -1 else in_features
        self.act_order = None

        q_rows = in_features // self.group_size
        self.register_buffer(
            "qweight",
            torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
        )
        self.register_buffer(
            "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
        )
        self.register_buffer(
            "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32)
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
        else:
            self.bias = None

    def quant_weight(self, weight, scales, zeros, g_idx):
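        """Quantize ``weight`` (shape ``(in_features, out_features)``) to integer
        levels via ``round(weight / scale + zero_point)``, with the per-column
        group selected through ``g_idx``."""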
        scale_zeros = zeros * scales
        scale_mat = scales[g_idx]
        scale_zeros_mat = scale_zeros[g_idx]
        int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int)
        return int_weight_T

    def pack_on_device(self, int_weight, int_zeros):
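        """Pack integer weights and zero points (two 4-bit values per byte) and fill
        the ``qweight``/``qzeros``/``scales`` buffers."""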
        if self.bits != 4:
            raise ValueError("Only 4-bit quantization is supported by ONNX Runtime for now.")

        # A non-trivial group order (act-order / desc_act) shows up as non-zero entries
        # at the start of g_idx.
        self.act_order = self.g_idx[: self.group_size // self.bits].sum().item() != 0

        intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte()
        scales_pt = self.scales.T.to(int_weight.device)
        intweight_pt = int_weight.byte()

        block_size = self.group_size
        rows, cols = intweight_pt.shape
        blob_size = block_size // 2
        k_blocks = (rows + block_size - 1) // block_size
        padded_rows = k_blocks * block_size
        pad_len = padded_rows - rows
        if pad_len > 0:
            intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0)
        intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0)

        # Pack zeros if they are not float
        if int_zeros.dtype != self.scales.dtype:
            intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4)
            intzeros_pt = intzeros_pt.reshape(-1)

        # Pack the (padded) weights: two 4-bit values per byte, blob_size bytes per block
        intweight_pt_T = intweight_pt.T
        intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4)
        intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size)

        scales_pt = scales_pt.reshape(-1)

        # Validation: the packed tensors must match the pre-registered buffer shapes
        # (the qzeros shape only has to match when the zero points keep the buffer's dtype)
        if self.qweight.shape != intweight_pt_T.shape or (
            self.qzeros.dtype == intzeros_pt.dtype and self.qzeros.shape != intzeros_pt.shape
        ):
            raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module")

        # Assign buffers
        self.scales = scales_pt.float()
        self.qweight = intweight_pt_T.byte()  # Convert to uint8
        if int_zeros.dtype != self.scales.dtype:
            self.qzeros = intzeros_pt.byte()  # Convert to uint8
        else:
            self.qzeros = intzeros_pt

    def pack(self, linear, scales, zeros, g_idx=None):
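        """Quantize ``linear.weight`` with the given per-group ``scales``/``zeros``
        and pack the result into this module's buffers."""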
        layer_weight = linear.weight.data
        self.scales = scales.T
        if g_idx is None:
            g_idx = self.g_idx
        else:
            self.g_idx = g_idx.clone()
        int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx)
        return self.pack_on_device(int_weight, zeros.T)

    def forward(self, inputs):
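        """Matrix-multiply ``inputs`` with the dequantized weight (plus ``bias`` if
        present), routed through QuantLinearTorchFunction for ONNX export."""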
        out = QuantLinearTorchFunction.apply(
            inputs,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx if self.act_order else None,
            self.bits,
            self.group_size,
            self.in_features,
            self.out_features,
        )
        out = out + self.bias if self.bias is not None else out
        return out
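

# -----------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module): quantize a toy nn.Linear to 4-bit weights with group_size=128, pack
# it into QuantLinearORT, and compare the eager forward pass against the float
# layer. The per-group scale/zero-point choice below (symmetric around the
# 4-bit mid-point) is only an example, not the calibration of any particular
# quantizer.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    in_features, out_features, group_size = 256, 128, 128

    float_linear = nn.Linear(in_features, out_features, bias=False)
    weight = float_linear.weight.data  # (out_features, in_features)

    # One scale/zero-point per (output channel, input group); scales chosen so
    # round(w / s + 8) stays inside the 4-bit range [0, 15].
    n_groups = in_features // group_size
    grouped = weight.reshape(out_features, n_groups, group_size)
    scales = (grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-6)  # (out_features, n_groups)
    zeros = torch.full_like(scales, 8.0)
    g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)

    qlinear = QuantLinearORT(bits=4, group_size=group_size, in_features=in_features,
                             out_features=out_features, bias=False)
    qlinear.pack(float_linear, scales, zeros, g_idx)

    x = torch.randn(2, in_features)
    print("max abs error vs. float linear:", (qlinear(x) - float_linear(x)).abs().max().item())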