14 changes: 14 additions & 0 deletions bench/qs8-qc4w-gemm-fp32.cc
@@ -24,6 +24,20 @@



#if XNN_ENABLE_HVX && XNN_ARCH_HEXAGON
static void qs8_qc4w_gemm_minmax_fp32_ukernel_1x128c4__hvx(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_1x128c4__hvx,
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w,
/*mr=*/1, /*nr=*/128, /*kr=*/4, /*sr=*/1,
/*arch_flags=*/xnn_arch_hvx);
}

BENCHMARK_GEMM(qs8_qc4w_gemm_minmax_fp32_ukernel_1x128c4__hvx)
#endif // XNN_ENABLE_HVX && XNN_ARCH_HEXAGON
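
This registers the new HVX QS8 x QC4W GEMM benchmark with a 1x128 output tile (MR=1, NR=128) that consumes K in groups of 4 (KR=4). The QC4W kernel shares the QC8 params initializer (xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params) and packs its 4-bit weights with xnn_pack_qs8_qc4w_gemm_goi_w.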


#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
static void qs8_qc4w_gemm_minmax_fp32_ukernel_1x16c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
8 changes: 8 additions & 0 deletions cmake/gen/hvx_microkernels.cmake
@@ -247,6 +247,14 @@ SET(NON_PROD_HVX_MICROKERNEL_SRCS
src/f32-vtanh/gen/f32-vtanh-hvx-rational-9-8-nr.c
src/qs8-packw/gen/qs8-packw-x96c4-gemm-gio-hvx.c
src/qs8-packw/gen/qs8-packw-x96c4-gemm-goi-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-5x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-6x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x128c4-minmax-fp32-hvx.c
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-8x128c4-minmax-fp32-hvx.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-hvx.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x64c4-minmax-fp32-hvx.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x128c4-minmax-fp32-hvx-prfm.c
8 changes: 8 additions & 0 deletions gen/hvx_microkernels.bzl
@@ -244,6 +244,14 @@ NON_PROD_HVX_MICROKERNEL_SRCS = [
"src/f32-vtanh/gen/f32-vtanh-hvx-rational-9-8-nr.c",
"src/qs8-packw/gen/qs8-packw-x96c4-gemm-gio-hvx.c",
"src/qs8-packw/gen/qs8-packw-x96c4-gemm-goi-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-5x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-6x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x128c4-minmax-fp32-hvx.c",
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-8x128c4-minmax-fp32-hvx.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-hvx.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x64c4-minmax-fp32-hvx.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x128c4-minmax-fp32-hvx-prfm.c",
9 changes: 9 additions & 0 deletions scripts/generate-qs8-gemm.sh
@@ -2083,4 +2083,13 @@ tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=6 -D NR=128 -D DATATYPE=QC8 -D PREFE
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=7 -D NR=128 -D DATATYPE=QC8 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x128c4-minmax-fp32-hvx-prfm.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=8 -D NR=128 -D DATATYPE=QC8 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x128c4-minmax-fp32-hvx-prfm.c &

tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=1 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=2 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=3 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=4 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=5 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-5x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=6 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-6x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=7 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x128c4-minmax-fp32-hvx.c &
tools/xngen src/qs8-gemm/c4-hvx.c.in -D MR=8 -D NR=128 -D DATATYPE=QS8_QC4 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-8x128c4-minmax-fp32-hvx.c &

wait
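
Each xngen invocation above expands the shared src/qs8-gemm/c4-hvx.c.in template for one MR in 1..8; DATATYPE=QS8_QC4 selects the 4-bit-weight code paths guarded by $if DATATYPE in the template, and PREFETCH=0 produces the variants without a -prfm suffix.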
16 changes: 12 additions & 4 deletions src/qs8-gemm/c4-hvx.c.in
@@ -154,8 +154,11 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
$for N in range(0, NR, 32):
const HVX_Vector vbs${N//32}x0123 = Q6_Vw_vasl_VwR(vb${N//32}x01234567, 4);
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x4567 = Q6_V_vand_VV(vb${N//32}x01234567, vmask);
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x0123 = Q6_V_vand_VV(vbs${N//32}x0123, vmask);
$else:
$for N in range(0, NR, 32):
@@ -179,8 +182,11 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
$for N in range(0, NR, 32):
const HVX_Vector vbs${N//32}x0123 = Q6_Vw_vasl_VwR(vb${N//32}x01234567, 4);
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x4567 = Q6_V_vand_VV(vb${N//32}x01234567, vmask);
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x0123 = Q6_V_vand_VV(vbs${N//32}x0123, vmask);
$else:
$for N in range(0, NR, 32):
@@ -201,10 +207,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c
const HVX_Vector va${M}x0123 = Q6_V_vsplat_R(unaligned_load_s32(a${M})); a${M} += 4;

$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
-      $for N in range(0, NR, 32):
-        const HVX_Vector vb${N//32}x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
-        const HVX_Vector vbs${N//32}x0123 = Q6_Vw_vasl_VwR(vb${N//32}x01234567, 4);
-        const HVX_Vector vb${N//32}x0123 = Q6_V_vand_VV(vbs${N//32}x0123, vmask);
+      $for N in range(0, NR, 32):
+        const HVX_Vector vb${N//32}x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
+      $for N in range(0, NR, 32):
+        const HVX_Vector vbs${N//32}x0123 = Q6_Vw_vasl_VwR(vb${N//32}x01234567, 4);
+      $for N in range(0, NR, 32):
+        const HVX_Vector vb${N//32}x0123 = Q6_V_vand_VV(vbs${N//32}x0123, vmask);
$else:
$for N in range(0, NR, 32):
const HVX_Vector vb${N//32}x0123 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
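The QS8_QC4 branch above unpacks two signed 4-bit weights per byte with one shift and two masks: Q6_Vw_vasl_VwR shifts each 32-bit word left by 4, moving every byte's low nibble into that byte's high half, and ANDing with vmask (0xF0, see the generated kernel below) isolates one nibble per byte, in place for the high nibbles and from the shifted copy for the low nibbles. Each unpacked byte thus holds its 4-bit value times 16, which the kernel undoes once after the K loop with an arithmetic shift right by 4. A minimal scalar sketch of the same trick (unpack_qc4_x16 is a hypothetical helper for illustration; the per-word HVX shift followed by the 0xF0 mask is equivalent to this per-byte shift):

  // Unpack two signed 4-bit weights from one byte, each scaled by 16.
  static inline void unpack_qc4_x16(uint8_t packed, int8_t* lo_x16, int8_t* hi_x16) {
    *lo_x16 = (int8_t) (packed << 4);    // low nibble -> high half; its sign bit lands in bit 7
    *hi_x16 = (int8_t) (packed & 0xF0);  // high nibble already sits in the high half
    // Dot products over these bytes come out 16x too large; the kernel fixes
    // this once at the end with Q6_Vw_vasr_VwR(vacc, 4).
  }

Splitting the unpacking into one $for loop per statement, instead of one loop emitting all statements, groups the 128-byte vector loads together, then the shifts, then the masks, likely to give the compiler more freedom to schedule loads ahead of the vector ALU ops.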
232 changes: 232 additions & 0 deletions src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x128c4-minmax-fp32-hvx.c
@@ -0,0 +1,232 @@
// clang-format off
// Auto-generated file. Do not edit!
// Template: src/qs8-gemm/c4-hvx.c.in
// Generator: tools/xngen
//
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h> // for lrintf
#include <stdio.h> // for printf

#include <hexagon_types.h>
#include <hexagon_protos.h>
#include <hvx_hexagon_protos.h>

#include "src/xnnpack/gemm.h"
#include "src/xnnpack/intrinsics-polyfill.h" // for Q6_V_vstu_variable
#include "src/xnnpack/math.h"
#include "src/xnnpack/unaligned.h"


// multiply vacc by vscale and return result as int
// vacc is vector of int32
// vscale is vector of floats
// return is vector of int
#if __HVX_ARCH__ >= 73
static XNN_INLINE HVX_Vector rescale_fp32(HVX_Vector vacc, HVX_Vector vscale)
{
const HVX_Vector vaccf = Q6_Vsf_equals_Vw(vacc);
const HVX_Vector vscaledqf = Q6_Vqf32_vmpy_VsfVsf(vaccf, vscale);

// Create a vector of `0.5f` with the same sign as the corresponding entries of vacc.
const HVX_Vector vhalf = Q6_V_vsplat_R(float_as_uint32(0.5f));
const HVX_Vector vsign_mask = Q6_V_vsplat_R(0x80000000);
const HVX_Vector vsigned_half = Q6_V_vor_VV(Q6_V_vand_VV(vaccf, vsign_mask), vhalf);
const HVX_Vector vresult = Q6_Vw_equals_Vsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(vscaledqf, vsigned_half)));
return vresult;
}
#else
static HVX_Vector rescale_fp32(HVX_Vector vacc, HVX_Vector vscale)
{
XNN_ALIGN(128) int32_t vacc_buffer[32];
XNN_ALIGN(128) float vscale_buffer[32];

*((HVX_Vector *)&vacc_buffer) = vacc;
*((HVX_Vector *)&vscale_buffer) = vscale;

for (int i = 0; i < 32; ++i) {
vacc_buffer[i] = (int32_t)lrintf((float)vacc_buffer[i] * vscale_buffer[i]);
}
return *(HVX_Vector *)&vacc_buffer;
}
#endif // __HVX_ARCH__ >= 73
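
// For reference, a scalar sketch of what the v73+ path above computes,
// assuming the final Q6_Vw_equals_Vsf convert truncates toward zero, so that
// adding a sign-matched 0.5f yields round-half-away-from-zero (the pre-v73
// fallback uses lrintf, which rounds half to even instead):
//
//   #include <math.h>  // for copysignf
//   // rescale_fp32_ref is a hypothetical model, not part of this PR.
//   static int32_t rescale_fp32_ref(int32_t acc, float scale) {
//     const float product = (float) acc * scale;
//     // Take the sign from the unscaled accumulator, as the HVX code does;
//     // requantization scales are positive, so this matches the product's sign.
//     return (int32_t) (product + copysignf(0.5f, (float) acc));
//   }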

void xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_1x128c4__hvx(
size_t mr,
size_t nc,
size_t kc,
const int8_t* restrict a,
size_t a_stride,
const void* restrict w,
int8_t* restrict c,
size_t cm_stride,
size_t cn_stride,
const union xnn_qs8_qc8w_conv_minmax_params* restrict params) XNN_OOB_READS
{
assert(mr != 0);
assert(mr <= 1);
assert(nc != 0);
assert(kc != 0);
assert(kc % sizeof(int8_t) == 0);
assert(a != NULL);
assert(w != NULL);
assert(c != NULL);

kc = round_up_po2(kc, 4 * sizeof(int8_t));
const int8_t* a0 = a;
int8_t* c0 = c;

// TODO: Use log when fixed
{
static int warning_unaligned = 0;
if ((a_stride & (sizeof(int32_t) - 1)) != 0 && warning_unaligned == 0) {
printf("HEXAGON GEMM a_stride unaligned.");
warning_unaligned = 1;
}
static int warning_a_unaligned = 0;
if ((((intptr_t) a) & (sizeof(int32_t) - 1)) != 0 && warning_a_unaligned == 0) {
printf("HEXAGON GEMM a unaligned.");
warning_a_unaligned = 1;
}
fflush(stdout);
}

const HVX_Vector voutput_zero_point = Q6_Vh_vsplat_R(params->fp32_scalar.output_zero_point);
const HVX_Vector voutput_min = Q6_Vb_vsplat_R(params->fp32_scalar.output_min);
const HVX_Vector voutput_max = Q6_Vb_vsplat_R(params->fp32_scalar.output_max);
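// 0xF0 keeps one 4-bit weight per byte (scaled by 16) during unpacking; see the shift-and-mask pairs in the K loops below.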
const HVX_Vector vmask = Q6_Vb_vsplat_R(0xF0);
do {
HVX_Vector vacc0x0 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
HVX_Vector vacc0x1 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
HVX_Vector vacc0x2 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
HVX_Vector vacc0x3 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;

size_t k = kc;
if (((((intptr_t) a) | a_stride) & (sizeof(int32_t) - 1)) != 0) {
for (; k >= 8 * sizeof(int8_t); k -= 8 * sizeof(int8_t)) {
const HVX_Vector va0x0123 = Q6_V_vsplat_R(unaligned_load_s32(a0));
const HVX_Vector va0x4567 = Q6_V_vsplat_R(unaligned_load_s32(a0+4)); a0 += 8;

const HVX_Vector vb0x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb1x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb2x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb3x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vbs0x0123 = Q6_Vw_vasl_VwR(vb0x01234567, 4);
const HVX_Vector vbs1x0123 = Q6_Vw_vasl_VwR(vb1x01234567, 4);
const HVX_Vector vbs2x0123 = Q6_Vw_vasl_VwR(vb2x01234567, 4);
const HVX_Vector vbs3x0123 = Q6_Vw_vasl_VwR(vb3x01234567, 4);
const HVX_Vector vb0x4567 = Q6_V_vand_VV(vb0x01234567, vmask);
const HVX_Vector vb1x4567 = Q6_V_vand_VV(vb1x01234567, vmask);
const HVX_Vector vb2x4567 = Q6_V_vand_VV(vb2x01234567, vmask);
const HVX_Vector vb3x4567 = Q6_V_vand_VV(vb3x01234567, vmask);
const HVX_Vector vb0x0123 = Q6_V_vand_VV(vbs0x0123, vmask);
const HVX_Vector vb1x0123 = Q6_V_vand_VV(vbs1x0123, vmask);
const HVX_Vector vb2x0123 = Q6_V_vand_VV(vbs2x0123, vmask);
const HVX_Vector vb3x0123 = Q6_V_vand_VV(vbs3x0123, vmask);

vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x0123, vb0x0123);
vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x0123, vb1x0123);
vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x0123, vb2x0123);
vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x0123, vb3x0123);
vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x4567, vb0x4567);
vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x4567, vb1x4567);
vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x4567, vb2x4567);
vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x4567, vb3x4567);
}
} else {
for (; k >= 8 * sizeof(int8_t); k -= 8 * sizeof(int8_t)) {
const HVX_Vector va0x0123 = Q6_V_vsplat_R(*((const int32_t*)a0));
const HVX_Vector va0x4567 = Q6_V_vsplat_R(*((const int32_t*) (a0 + 4))); a0 += 8;

const HVX_Vector vb0x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb1x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb2x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb3x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vbs0x0123 = Q6_Vw_vasl_VwR(vb0x01234567, 4);
const HVX_Vector vbs1x0123 = Q6_Vw_vasl_VwR(vb1x01234567, 4);
const HVX_Vector vbs2x0123 = Q6_Vw_vasl_VwR(vb2x01234567, 4);
const HVX_Vector vbs3x0123 = Q6_Vw_vasl_VwR(vb3x01234567, 4);
const HVX_Vector vb0x4567 = Q6_V_vand_VV(vb0x01234567, vmask);
const HVX_Vector vb1x4567 = Q6_V_vand_VV(vb1x01234567, vmask);
const HVX_Vector vb2x4567 = Q6_V_vand_VV(vb2x01234567, vmask);
const HVX_Vector vb3x4567 = Q6_V_vand_VV(vb3x01234567, vmask);
const HVX_Vector vb0x0123 = Q6_V_vand_VV(vbs0x0123, vmask);
const HVX_Vector vb1x0123 = Q6_V_vand_VV(vbs1x0123, vmask);
const HVX_Vector vb2x0123 = Q6_V_vand_VV(vbs2x0123, vmask);
const HVX_Vector vb3x0123 = Q6_V_vand_VV(vbs3x0123, vmask);

vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x0123, vb0x0123);
vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x0123, vb1x0123);
vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x0123, vb2x0123);
vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x0123, vb3x0123);
vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x4567, vb0x4567);
vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x4567, vb1x4567);
vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x4567, vb2x4567);
vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x4567, vb3x4567);
}
}
if (k != 0) {
const HVX_Vector va0x0123 = Q6_V_vsplat_R(unaligned_load_s32(a0)); a0 += 4;

const HVX_Vector vb0x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb1x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb2x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vb3x01234567 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
const HVX_Vector vbs0x0123 = Q6_Vw_vasl_VwR(vb0x01234567, 4);
const HVX_Vector vbs1x0123 = Q6_Vw_vasl_VwR(vb1x01234567, 4);
const HVX_Vector vbs2x0123 = Q6_Vw_vasl_VwR(vb2x01234567, 4);
const HVX_Vector vbs3x0123 = Q6_Vw_vasl_VwR(vb3x01234567, 4);
const HVX_Vector vb0x0123 = Q6_V_vand_VV(vbs0x0123, vmask);
const HVX_Vector vb1x0123 = Q6_V_vand_VV(vbs1x0123, vmask);
const HVX_Vector vb2x0123 = Q6_V_vand_VV(vbs2x0123, vmask);
const HVX_Vector vb3x0123 = Q6_V_vand_VV(vbs3x0123, vmask);

vacc0x0 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x0, va0x0123, vb0x0123);
vacc0x1 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x1, va0x0123, vb1x0123);
vacc0x2 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x2, va0x0123, vb2x0123);
vacc0x3 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x3, va0x0123, vb3x0123);
}

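// Every unpacked weight byte carried its 4-bit value times 16, so the accumulators are 16x too large; one arithmetic shift right by 4 compensates.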
vacc0x0 = Q6_Vw_vasr_VwR(vacc0x0, 4);
vacc0x1 = Q6_Vw_vasr_VwR(vacc0x1, 4);
vacc0x2 = Q6_Vw_vasr_VwR(vacc0x2, 4);
vacc0x3 = Q6_Vw_vasr_VwR(vacc0x3, 4);

const HVX_Vector vscale0 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
vacc0x0 = rescale_fp32(vacc0x0, vscale0);
const HVX_Vector vscale1 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
vacc0x1 = rescale_fp32(vacc0x1, vscale1);
const HVX_Vector vscale2 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
vacc0x2 = rescale_fp32(vacc0x2, vscale2);
const HVX_Vector vscale3 = *((HVX_Vector *) w); w = (const int8_t*) w + 128;
vacc0x3 = rescale_fp32(vacc0x3, vscale3);

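// Saturating pack 32-bit accumulators to 16-bit, add the output zero point in 16-bit, then pack to 8-bit and clamp to [output_min, output_max].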
HVX_Vector vout0x0 = Q6_Vh_vpack_VwVw_sat(vacc0x1, vacc0x0);
HVX_Vector vout0x1 = Q6_Vh_vpack_VwVw_sat(vacc0x3, vacc0x2);

vout0x0 = Q6_Vh_vadd_VhVh_sat(vout0x0, voutput_zero_point);
vout0x1 = Q6_Vh_vadd_VhVh_sat(vout0x1, voutput_zero_point);

HVX_Vector vout0 = Q6_Vb_vpack_VhVh_sat(vout0x1, vout0x0);

vout0 = Q6_Vb_vmax_VbVb(vout0, voutput_min);

vout0 = Q6_Vb_vmin_VbVb(vout0, voutput_max);

if XNN_LIKELY(nc >= 128) {
*((HVX_UVector *)c0) = vout0;
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
a0 = (const int8_t*) ((uintptr_t) a0 - kc);

nc -= 128;
} else {
// Prepare mask for valid 8-bit elements (depends on nc).
Q6_V_vstu_variable(c0, nc, vout0);
nc = 0;
}
} while (nc != 0);
}