// clang-format off
// Auto-generated file. Do not edit!
// Template: src/qs8-gemm/c4-hvx.c.in
// Generator: tools/xngen
//
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>  // for lrintf
#include <stdio.h>  // for printf

#include <hexagon_types.h>
#include <hexagon_protos.h>
#include <hvx_hexagon_protos.h>

#include "src/xnnpack/gemm.h"
#include "src/xnnpack/intrinsics-polyfill.h"  // for Q6_V_vstu_variable
#include "src/xnnpack/math.h"
#include "src/xnnpack/unaligned.h"


// Multiply vacc (a vector of 32 int32 accumulators) by vscale (a vector of 32 fp32
// scales) and return the result rounded to a vector of int32.
#if __HVX_ARCH__ >= 73
static XNN_INLINE HVX_Vector rescale_fp32(HVX_Vector vacc, HVX_Vector vscale)
{
  const HVX_Vector vaccf = Q6_Vsf_equals_Vw(vacc);
  const HVX_Vector vscaledqf = Q6_Vqf32_vmpy_VsfVsf(vaccf, vscale);

  // Create a vector of `0.5f` with the same sign as the entries of `vaccf`.
  const HVX_Vector vhalf = Q6_V_vsplat_R(float_as_uint32(0.5f));
  const HVX_Vector vsign_mask = Q6_V_vsplat_R(0x80000000);
  const HVX_Vector vsigned_half = Q6_V_vor_VV(Q6_V_vand_VV(vaccf, vsign_mask), vhalf);
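  // Add the signed 0.5f before converting back to int32 so the scaled accumulator
  // rounds to the nearest integer (ties away from zero).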
  const HVX_Vector vresult = Q6_Vw_equals_Vsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(vscaledqf, vsigned_half)));
  return vresult;
}
#else
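// Scalar fallback for HVX architectures below v73: spill both vectors to aligned
// buffers, round each lane with lrintf, and reload the result as a vector.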
static HVX_Vector rescale_fp32(HVX_Vector vacc, HVX_Vector vscale)
{
  XNN_ALIGN(128) int32_t vacc_buffer[32];
  XNN_ALIGN(128) float vscale_buffer[32];

  *((HVX_Vector *)&vacc_buffer) = vacc;
  *((HVX_Vector *)&vscale_buffer) = vscale;

  for (int i = 0; i < 32; ++i) {
    vacc_buffer[i] = (int32_t)lrintf((float)vacc_buffer[i] * vscale_buffer[i]);
  }
  return *(HVX_Vector *)&vacc_buffer;
}
#endif  // __HVX_ARCH__ >= 73

void xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_1x128c4__hvx(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_qc8w_conv_minmax_params* restrict params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 4 * sizeof(int8_t));
  const int8_t* a0 = a;
  int8_t* c0 = c;

  // TODO: Replace printf with XNNPACK logging once it works on this target.
  {
    static int warning_unaligned = 0;
    if ((a_stride & (sizeof(int32_t) - 1)) != 0 && warning_unaligned == 0) {
      printf("HEXAGON GEMM a_stride unaligned.\n");
      warning_unaligned = 1;
    }
    static int warning_a_unaligned = 0;
    if ((((intptr_t) a) & (sizeof(int32_t) - 1)) != 0 && warning_a_unaligned == 0) {
      printf("HEXAGON GEMM a unaligned.\n");
      warning_a_unaligned = 1;
    }
    fflush(stdout);
  }

  const HVX_Vector voutput_zero_point = Q6_Vh_vsplat_R(params->fp32_scalar.output_zero_point);
  const HVX_Vector voutput_min = Q6_Vb_vsplat_R(params->fp32_scalar.output_min);
  const HVX_Vector voutput_max = Q6_Vb_vsplat_R(params->fp32_scalar.output_max);
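  // Weights are packed two 4-bit values per byte. Each nibble is isolated in the
  // upper 4 bits of its byte (mask 0xF0), so every product is scaled by 16; this is
  // compensated by the arithmetic shift right by 4 after the K loop.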
  const HVX_Vector vmask = Q6_Vb_vsplat_R(0xF0);
  do {
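    // Initialize the 128 per-channel accumulators (4 vectors of 32 int32 lanes) from
    // the bias values at the start of the packed weights.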
    HVX_Vector vacc0x0 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    HVX_Vector vacc0x1 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    HVX_Vector vacc0x2 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    HVX_Vector vacc0x3 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;

    size_t k = kc;
    if (((((intptr_t) a) | a_stride) & (sizeof(int32_t) - 1)) != 0) {
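      // Unaligned path: the activation pointer or row stride is not 4-byte aligned,
      // so groups of 4 int8 activations are read with unaligned_load_s32.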
      for (; k >= 8 * sizeof(int8_t); k -= 8 * sizeof(int8_t)) {
        const HVX_Vector va0x0123 = Q6_V_vsplat_R(unaligned_load_s32(a0));
        const HVX_Vector va0x4567 = Q6_V_vsplat_R(unaligned_load_s32(a0 + 4)); a0 += 8;

        const HVX_Vector vb0x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs0x0123 = Q6_Vw_vasl_VwR(vb0x01234567, 4);
        const HVX_Vector vb0x4567 = Q6_V_vand_VV(vb0x01234567, vmask);
        const HVX_Vector vb0x0123 = Q6_V_vand_VV(vbs0x0123, vmask);
        const HVX_Vector vb1x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs1x0123 = Q6_Vw_vasl_VwR(vb1x01234567, 4);
        const HVX_Vector vb1x4567 = Q6_V_vand_VV(vb1x01234567, vmask);
        const HVX_Vector vb1x0123 = Q6_V_vand_VV(vbs1x0123, vmask);
        const HVX_Vector vb2x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs2x0123 = Q6_Vw_vasl_VwR(vb2x01234567, 4);
        const HVX_Vector vb2x4567 = Q6_V_vand_VV(vb2x01234567, vmask);
        const HVX_Vector vb2x0123 = Q6_V_vand_VV(vbs2x0123, vmask);
        const HVX_Vector vb3x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs3x0123 = Q6_Vw_vasl_VwR(vb3x01234567, 4);
        const HVX_Vector vb3x4567 = Q6_V_vand_VV(vb3x01234567, vmask);
        const HVX_Vector vb3x0123 = Q6_V_vand_VV(vbs3x0123, vmask);

        vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x0123, vb0x0123);
        vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x0123, vb1x0123);
        vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x0123, vb2x0123);
        vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x0123, vb3x0123);
        vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x4567, vb0x4567);
        vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x4567, vb1x4567);
        vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x4567, vb2x4567);
        vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x4567, vb3x4567);
      }
    } else {
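      // Aligned path: a0 and a_stride are 4-byte aligned, so direct 32-bit loads are safe.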
      for (; k >= 8 * sizeof(int8_t); k -= 8 * sizeof(int8_t)) {
        const HVX_Vector va0x0123 = Q6_V_vsplat_R(*((const int32_t*)a0));
        const HVX_Vector va0x4567 = Q6_V_vsplat_R(*((const int32_t*)(a0 + 4))); a0 += 8;

        const HVX_Vector vb0x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs0x0123 = Q6_Vw_vasl_VwR(vb0x01234567, 4);
        const HVX_Vector vb0x4567 = Q6_V_vand_VV(vb0x01234567, vmask);
        const HVX_Vector vb0x0123 = Q6_V_vand_VV(vbs0x0123, vmask);
        const HVX_Vector vb1x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs1x0123 = Q6_Vw_vasl_VwR(vb1x01234567, 4);
        const HVX_Vector vb1x4567 = Q6_V_vand_VV(vb1x01234567, vmask);
        const HVX_Vector vb1x0123 = Q6_V_vand_VV(vbs1x0123, vmask);
        const HVX_Vector vb2x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs2x0123 = Q6_Vw_vasl_VwR(vb2x01234567, 4);
        const HVX_Vector vb2x4567 = Q6_V_vand_VV(vb2x01234567, vmask);
        const HVX_Vector vb2x0123 = Q6_V_vand_VV(vbs2x0123, vmask);
        const HVX_Vector vb3x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
        const HVX_Vector vbs3x0123 = Q6_Vw_vasl_VwR(vb3x01234567, 4);
        const HVX_Vector vb3x4567 = Q6_V_vand_VV(vb3x01234567, vmask);
        const HVX_Vector vb3x0123 = Q6_V_vand_VV(vbs3x0123, vmask);

        vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x0123, vb0x0123);
        vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x0123, vb1x0123);
        vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x0123, vb2x0123);
        vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x0123, vb3x0123);
        vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x4567, vb0x4567);
        vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x4567, vb1x4567);
        vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x4567, vb2x4567);
        vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x4567, vb3x4567);
      }
    }
    if (k != 0) {
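      // KC remainder (k == 4 after rounding kc up): multiply the final group of 4
      // activations against the low weight nibbles only.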
      const HVX_Vector va0x0123 = Q6_V_vsplat_R(unaligned_load_s32(a0)); a0 += 4;

      const HVX_Vector vb0x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
      const HVX_Vector vbs0x0123 = Q6_Vw_vasl_VwR(vb0x01234567, 4);
      const HVX_Vector vb0x0123 = Q6_V_vand_VV(vbs0x0123, vmask);
      const HVX_Vector vb1x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
      const HVX_Vector vbs1x0123 = Q6_Vw_vasl_VwR(vb1x01234567, 4);
      const HVX_Vector vb1x0123 = Q6_V_vand_VV(vbs1x0123, vmask);
      const HVX_Vector vb2x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
      const HVX_Vector vbs2x0123 = Q6_Vw_vasl_VwR(vb2x01234567, 4);
      const HVX_Vector vb2x0123 = Q6_V_vand_VV(vbs2x0123, vmask);
      const HVX_Vector vb3x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
      const HVX_Vector vbs3x0123 = Q6_Vw_vasl_VwR(vb3x01234567, 4);
      const HVX_Vector vb3x0123 = Q6_V_vand_VV(vbs3x0123, vmask);

      vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x0123, vb0x0123);
      vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x0123, vb1x0123);
      vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x0123, vb2x0123);
      vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x0123, vb3x0123);
    }

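    // Shift right by 4 to undo the x16 scaling from keeping the weight nibbles in
    // the upper 4 bits of each byte.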
    vacc0x0 = Q6_Vw_vasr_VwR(vacc0x0, 4);
    vacc0x1 = Q6_Vw_vasr_VwR(vacc0x1, 4);
    vacc0x2 = Q6_Vw_vasr_VwR(vacc0x2, 4);
    vacc0x3 = Q6_Vw_vasr_VwR(vacc0x3, 4);

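    // Requantize: convert the accumulators to fp32, multiply by the per-output-channel
    // scales stored after the weights, and round back to int32.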
    const HVX_Vector vscale0 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    vacc0x0 = rescale_fp32(vacc0x0, vscale0);
    const HVX_Vector vscale1 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    vacc0x1 = rescale_fp32(vacc0x1, vscale1);
    const HVX_Vector vscale2 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    vacc0x2 = rescale_fp32(vacc0x2, vscale2);
    const HVX_Vector vscale3 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
    vacc0x3 = rescale_fp32(vacc0x3, vscale3);

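    // Pack int32 -> int16 -> int8 with saturation, adding the output zero point in
    // the int16 domain, then clamp to the requested output range.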
    HVX_Vector vout0x0 = Q6_Vh_vpack_VwVw_sat(vacc0x1, vacc0x0);
    HVX_Vector vout0x1 = Q6_Vh_vpack_VwVw_sat(vacc0x3, vacc0x2);

    vout0x0 = Q6_Vh_vadd_VhVh_sat(vout0x0, voutput_zero_point);
    vout0x1 = Q6_Vh_vadd_VhVh_sat(vout0x1, voutput_zero_point);

    HVX_Vector vout0 = Q6_Vb_vpack_VhVh_sat(vout0x1, vout0x0);

    vout0 = Q6_Vb_vmax_VbVb(vout0, voutput_min);

    vout0 = Q6_Vb_vmin_VbVb(vout0, voutput_max);

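    // Full tiles store all 128 int8 outputs at once; the final partial tile writes
    // only the remaining nc bytes.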
    if XNN_LIKELY(nc >= 128) {
      *((HVX_UVector *)c0) = vout0;
      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
      a0 = (const int8_t*) ((uintptr_t) a0 - kc);

      nc -= 128;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      Q6_V_vstu_variable(c0, nc, vout0);
      nc = 0;
    }
  } while (nc != 0);
}