diff --git a/CMakeLists.txt b/CMakeLists.txt index 64a9d9cb3d9..dbffa4cbc40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -501,6 +501,7 @@ SET(OPERATOR_SRCS src/operators/dynamic-fully-connected-nc.c src/operators/fully-connected-nc.c src/operators/max-pooling-nhwc.c + src/operators/normalize-nc.c src/operators/pack-lh.c src/operators/reduce-nd.c src/operators/resize-bilinear-nchw.c @@ -535,6 +536,7 @@ SET(SUBGRAPH_SRCS src/subgraph/fully-connected-sparse.c src/subgraph/fully-connected.c src/subgraph/max-pooling-2d.c + src/subgraph/normalize.c src/subgraph/pack-lh.c src/subgraph/reshape-helpers.c src/subgraph/rope.c diff --git a/bench/BUILD.bazel b/bench/BUILD.bazel index b6c8127a648..e103e9b5ae5 100644 --- a/bench/BUILD.bazel +++ b/bench/BUILD.bazel @@ -97,6 +97,7 @@ xnnpack_cc_library( xnnpack_cxx_library( name = "gemm_benchmark", + testonly = True, srcs = [ "gemm-benchmark.cc", ], @@ -115,6 +116,7 @@ xnnpack_cxx_library( xnnpack_cxx_library( name = "packw_benchmark", + testonly = True, hdrs = [ "packw-benchmark.h", ], @@ -126,6 +128,7 @@ xnnpack_cxx_library( xnnpack_cxx_library( name = "bgemm", + testonly = True, hdrs = [ "bgemm.h", ], @@ -134,6 +137,19 @@ xnnpack_cxx_library( ], ) +xnnpack_cxx_library( + name = "packq_benchmark", + testonly = True, + srcs = [ + "packq-benchmark.cc", + ], + hdrs = ["packq-benchmark.h"], + deps = MICROKERNEL_BENCHMARK_DEPS + [ + ":bgemm", + "@com_google_benchmark//:benchmark", + ], +) + ######################### Benchmarks for micro-kernels ######################### [xnnpack_benchmark( @@ -275,8 +291,10 @@ xnnpack_benchmark( "f32_vcmul", "rdminmax", "rdsum", + "rdsum2", "rminmax", "rsum", + "rsum2", "x8_lut", ]] @@ -453,18 +471,6 @@ xnnpack_benchmark( ], ) -xnnpack_cxx_library( - name = "packq_benchmark", - srcs = [ - "packq-benchmark.cc", - ], - hdrs = ["packq-benchmark.h"], - deps = MICROKERNEL_BENCHMARK_DEPS + [ - ":bgemm", - "@com_google_benchmark//:benchmark", - ], -) - xnnpack_benchmark( name = "x8_packq_bench", srcs = [ diff --git a/bench/rdsum2.cc b/bench/rdsum2.cc new file mode 100644 index 00000000000..d90262b3aa0 --- /dev/null +++ b/bench/rdsum2.cc @@ -0,0 +1,78 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "bench/utils.h" +#include "src/xnnpack/buffer.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/hardware-config.h" // IWYU pragma: keep +#include "src/xnnpack/reduce.h" // IWYU pragma: keep +#include + +// Microkernel function, templated on the `params` type. 
+template +using UKernelFn = void (*)(size_t, size_t, size_t, size_t, const Input*, size_t, + size_t, size_t, const Input*, Output*, + const UKernelParams*); + +template +static void reduce(benchmark::State& state, uint64_t arch_flags, + UKernelFn ukernel) { + if (!benchmark::utils::CheckArchFlags(state, arch_flags)) { + return; + } + + const size_t channels = state.range(0); + const size_t rows = state.range(1); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + + xnnpack::Buffer input( + channels * rows, xnnpack::XnnExtraBytes); + xnnpack::Buffer zero(channels, 0, + xnnpack::XnnExtraBytes); + xnnpack::fill_uniform_random_bits(input.data(), input.size(), rng); + xnnpack::Buffer output(channels); + + UKernelParams params; + memset(¶ms, 0, sizeof(params)); + + for (auto _ : state) { + ukernel(channels, rows, 1, 1, input.data(), channels * sizeof(Input), 0, 0, + zero.data(), output.data(), ¶ms); + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + const size_t elements_per_iteration = channels * rows; + state.counters["elements"] = benchmark::Counter( + static_cast(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); + + const size_t bytes_per_iteration = channels * rows * sizeof(Input); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); +} + +#define XNN_UKERNEL(arch_flags, ukernel, row_tile, batch_tile, vector_tile, \ + datatype_in, datatype_out, params_type, init_params) \ + BENCHMARK_CAPTURE(reduce, ukernel, arch_flags, ukernel) \ + ->Apply(benchmark::utils::ReduceDiscontiguousParameters) \ + ->UseRealTime(); +// #include "src/f16-f32acc-rdsum/f16-f32acc-rdsum.inc" +#include "src/f32-rdsum/f32-rdsum.inc" +#undef XNN_UKERNEL + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +XNN_BENCHMARK_MAIN(); +#endif diff --git a/bench/rsum2.cc b/bench/rsum2.cc new file mode 100644 index 00000000000..d40d6ebbd20 --- /dev/null +++ b/bench/rsum2.cc @@ -0,0 +1,76 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "bench/utils.h" +#include "src/xnnpack/buffer.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/hardware-config.h" // IWYU pragma: keep +#include "src/xnnpack/reduce.h" // IWYU pragma: keep +#include + +// Microkernel function, templated on the `params` type. 
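For reference, the contract exercised by the rdsum2/rsum2 benchmarks above can be modeled in a few lines of plain C: a scaled sum of squares that accumulates into whatever the output already holds, matching the load-add-store pattern visible in the generated kernels later in this diff. This is only an illustrative sketch; the function name and the simplified parameter list are hypothetical, not the real ukernel entry points, and the contiguous rsum2 variant performs the same reduction over a single run of elements into one scalar.

#include <stddef.h>
#include <stdint.h>

// Illustrative scalar model of an f32 rdsum2-style reduction: for every
// channel, sum the squares across `rows`, scale the result, and add it to the
// value already stored in `output`. `input_stride` is in bytes, as in the
// benchmark harness above.
static void rdsum2_reference(size_t channels, size_t rows, const float* input,
                             size_t input_stride, float scale, float* output) {
  for (size_t c = 0; c < channels; ++c) {
    float acc = 0.0f;
    for (size_t r = 0; r < rows; ++r) {
      const float* row = (const float*) ((uintptr_t) input + r * input_stride);
      acc += row[c] * row[c];
    }
    output[c] += acc * scale;
  }
}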
+template +using UKernelFn = void (*)(size_t, const Input*, Output*, const UKernelParams*); + +template +static void reduce(benchmark::State& state, uint64_t arch_flags, + UKernelFn ukernel) { + if (!benchmark::utils::CheckArchFlags(state, arch_flags)) { + return; + } + + const size_t channels = state.range(0); + const size_t rows = state.range(1); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + + xnnpack::Buffer input( + channels * rows, xnnpack::XnnExtraBytes); + xnnpack::fill_uniform_random_bits(input.data(), input.size(), rng); + + UKernelParams params; + memset(¶ms, 0, sizeof(params)); + + Output output = 0; + for (auto _ : state) { + for (size_t r = 0; r < rows; ++r) { + ukernel(channels * sizeof(Input), input.data() + r * channels, &output, + ¶ms); + } + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + const size_t elements_per_iteration = rows * channels; + state.counters["elements"] = benchmark::Counter( + static_cast(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); + + const size_t bytes_per_iteration = rows * channels * sizeof(Input); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); +} + +#define XNN_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, datatype_in, \ + datatype_out, params_type, init_params) \ + BENCHMARK_CAPTURE(reduce, ukernel, arch_flags, ukernel) \ + ->Apply(benchmark::utils::ReduceParameters) \ + ->UseRealTime(); +#include "src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc" +#include "src/f32-rsum2/f32-rsum2.inc" +#undef XNN_UKERNEL + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +XNN_BENCHMARK_MAIN(); +#endif diff --git a/bench/subgraph/fp32-l2-norm.cc b/bench/subgraph/fp32-l2-norm.cc index da974879c6f..29847e6dffe 100644 --- a/bench/subgraph/fp32-l2-norm.cc +++ b/bench/subgraph/fp32-l2-norm.cc @@ -93,7 +93,7 @@ xnn_subgraph_t FP32L2Norm(size_t m, size_t n, size_t k, uint32_t norm_mask) { uint32_t inv_sqrt_sum_sq = XNN_INVALID_VALUE_ID; status = xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, dims.size(), dims.data(), + subgraph, xnn_datatype_fp32, reduction_dims.size(), reduction_dims.data(), /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &inv_sqrt_sum_sq); if (status != xnn_status_success) { std::cerr << "failed to create tensor inv_sqrt_sum_sq" << std::endl; diff --git a/bench/utils.h b/bench/utils.h index 3a99e2433f2..999c1caee23 100644 --- a/bench/utils.h +++ b/bench/utils.h @@ -9,7 +9,6 @@ #include #include #include -#include #include #include "src/xnnpack/common.h" diff --git a/build_srcs.bzl b/build_srcs.bzl index ef98269fc28..4b1d8883118 100644 --- a/build_srcs.bzl +++ b/build_srcs.bzl @@ -25,6 +25,7 @@ OPERATOR_SRCS = [ "src/operators/dynamic-fully-connected-nc.c", "src/operators/fully-connected-nc.c", "src/operators/max-pooling-nhwc.c", + "src/operators/normalize-nc.c", "src/operators/pack-lh.c", "src/operators/reduce-nd.c", "src/operators/resize-bilinear-nchw.c", @@ -56,6 +57,7 @@ SUBGRAPH_SRCS = [ "src/subgraph/fully-connected-sparse.c", "src/subgraph/fully-connected.c", "src/subgraph/max-pooling-2d.c", + "src/subgraph/normalize.c", "src/subgraph/pack-lh.c", "src/subgraph/reshape-helpers.c", "src/subgraph/rope.c", @@ -99,7 +101,9 @@ MICROKERNEL_DEFS = [ "src/f16-dwconv/f16-dwconv-minmax.inc", "src/f16-f32-vcvt/f16-f32-vcvt.inc", "src/f16-f32acc-rdsum/f16-f32acc-rdsum.inc", + 
"src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc", "src/f16-f32acc-rsum/f16-f32acc-rsum.inc", + "src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc", "src/f16-maxpool/f16-maxpool-minmax.inc", "src/f16-qs8-vcvt/f16-qs8-vcvt.inc", "src/f16-qu8-vcvt/f16-qu8-vcvt.inc", @@ -163,10 +167,12 @@ MICROKERNEL_DEFS = [ "src/f32-rdminmax/f32-rdmax.inc", "src/f32-rdminmax/f32-rdmin.inc", "src/f32-rdsum/f32-rdsum.inc", + "src/f32-rdsum2/f32-rdsum2.inc", "src/f32-rminmax/f32-rmax.inc", "src/f32-rminmax/f32-rmin.inc", "src/f32-rminmax/f32-rminmax.inc", "src/f32-rsum/f32-rsum.inc", + "src/f32-rsum2/f32-rsum2.inc", "src/f32-spmm/f32-spmm-minmax.inc", "src/f32-vabs/f32-vabs.inc", "src/f32-vapproxgelu/f32-vapproxgelu.inc", diff --git a/cmake/gen/avx512f_microkernels.cmake b/cmake/gen/avx512f_microkernels.cmake index 5ae3376aee0..92991240cce 100644 --- a/cmake/gen/avx512f_microkernels.cmake +++ b/cmake/gen/avx512f_microkernels.cmake @@ -27,10 +27,12 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS src/f32-rdminmax/gen/f32-rdmax-2p2x-avx512f-u32.c src/f32-rdminmax/gen/f32-rdmin-2p2x-avx512f-u32.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-u64.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c src/f32-rminmax/gen/f32-rmin-avx512f-u64-acc4.c src/f32-rminmax/gen/f32-rminmax-avx512f-u64-acc4.c src/f32-rsum/gen/f32-rsum-avx512f-u32-acc2.c + src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c src/f32-vapproxgelu/gen/f32-vapproxgelu-avx512f-rational-12-10-div.c src/f32-vbinary/gen/f32-vadd-avx512f-u32.c src/f32-vbinary/gen/f32-vaddc-avx512f-u32.c diff --git a/cmake/gen/avx512skx_microkernels.cmake b/cmake/gen/avx512skx_microkernels.cmake index bf0d817a31c..7041481c786 100644 --- a/cmake/gen/avx512skx_microkernels.cmake +++ b/cmake/gen/avx512skx_microkernels.cmake @@ -12,7 +12,9 @@ SET(PROD_AVX512SKX_MICROKERNEL_SRCS src/f16-f32-vcvt/gen/f16-f32-vcvt-avx512skx-u16.c src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u64.c + src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c src/f16-f32acc-rsum/gen/f16-f32acc-rsum-avx512skx-u32-acc2.c + src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c src/f16-rminmax/gen/f16-rmax-avx512skx-u64-acc4.c src/f16-rminmax/gen/f16-rmin-avx512skx-u64-acc4.c src/f16-rminmax/gen/f16-rminmax-avx512skx-u64-acc4.c diff --git a/cmake/gen/avx_microkernels.cmake b/cmake/gen/avx_microkernels.cmake index 9961ad9ec1e..f0e136ea714 100644 --- a/cmake/gen/avx_microkernels.cmake +++ b/cmake/gen/avx_microkernels.cmake @@ -34,10 +34,12 @@ SET(PROD_AVX_MICROKERNEL_SRCS src/f32-rdminmax/gen/f32-rdmax-2p2x-avx-u32.c src/f32-rdminmax/gen/f32-rdmin-2p2x-avx-u32.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-u32.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c src/f32-rminmax/gen/f32-rmax-avx-u32-acc4.c src/f32-rminmax/gen/f32-rmin-avx-u32-acc4.c src/f32-rminmax/gen/f32-rminmax-avx-u32-acc4.c src/f32-rsum/gen/f32-rsum-avx-u32-acc4.c + src/f32-rsum2/gen/f32-rsum2-avx-u8.c src/f32-vapproxgelu/gen/f32-vapproxgelu-avx-rational-12-10-div.c src/f32-vbinary/gen/f32-vadd-avx-u16.c src/f32-vbinary/gen/f32-vaddc-avx-u16.c diff --git a/cmake/gen/f16c_microkernels.cmake b/cmake/gen/f16c_microkernels.cmake index fd72e517422..c98a036ffda 100644 --- a/cmake/gen/f16c_microkernels.cmake +++ b/cmake/gen/f16c_microkernels.cmake @@ -13,7 +13,9 @@ SET(PROD_F16C_MICROKERNEL_SRCS src/f16-avgpool/gen/f16-avgpool-9p-minmax-f16c.c src/f16-f32-vcvt/gen/f16-f32-vcvt-f16c-u16.c src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-f16c-u32.c + src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-f16c.c 
src/f16-f32acc-rsum/gen/f16-f32acc-rsum-f16c-u32-acc4.c + src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-f16c.c src/f16-rminmax/f16-rmax-f16c-u32.c src/f16-vbinary/gen/f16-vadd-f16c-u16.c src/f16-vbinary/gen/f16-vaddc-f16c-u16.c diff --git a/cmake/gen/hvx_microkernels.cmake b/cmake/gen/hvx_microkernels.cmake index 674492a4084..01d1edda5ef 100644 --- a/cmake/gen/hvx_microkernels.cmake +++ b/cmake/gen/hvx_microkernels.cmake @@ -163,6 +163,7 @@ SET(NON_PROD_HVX_MICROKERNEL_SRCS src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-hvx-rr2-p5-u128-acc4.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-hvx-u32.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-hvx-u64.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-hvx.c src/f32-rminmax/gen/f32-rmax-hvx-u32.c src/f32-rminmax/gen/f32-rmax-hvx-u96-acc3.c src/f32-rminmax/gen/f32-rmax-hvx-u128-acc2.c diff --git a/cmake/gen/neon_microkernels.cmake b/cmake/gen/neon_microkernels.cmake index 63a5678ba2f..83d4178be19 100644 --- a/cmake/gen/neon_microkernels.cmake +++ b/cmake/gen/neon_microkernels.cmake @@ -41,10 +41,12 @@ SET(PROD_NEON_MICROKERNEL_SRCS src/f32-rdminmax/gen/f32-rdmax-2p2x-neon-u32.c src/f32-rdminmax/gen/f32-rdmin-2p2x-neon-u32.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-u16.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-neon.c src/f32-rminmax/gen/f32-rmax-neon-u16-acc4.c src/f32-rminmax/gen/f32-rmin-neon-u16-acc4.c src/f32-rminmax/gen/f32-rminmax-neon-u16-acc4.c src/f32-rsum/gen/f32-rsum-neon-u16-acc4.c + src/f32-rsum2/gen/f32-rsum2-neon.c src/f32-spmm/gen/f32-spmm-32x1-minmax-neon.c src/f32-vapproxgelu/gen/f32-vapproxgelu-neon-rational-12-10-div.c src/f32-vbinary/gen/f32-vadd-neon-u8.c diff --git a/cmake/gen/neonfp16arith_microkernels.cmake b/cmake/gen/neonfp16arith_microkernels.cmake index da4aca2fb4f..5d768995adc 100644 --- a/cmake/gen/neonfp16arith_microkernels.cmake +++ b/cmake/gen/neonfp16arith_microkernels.cmake @@ -22,7 +22,9 @@ SET(PROD_NEONFP16ARITH_MICROKERNEL_SRCS src/f16-dwconv2d-chw/gen/f16-dwconv2d-chw-5x5p2-minmax-neonfp16arith-1x8.c src/f16-dwconv2d-chw/gen/f16-dwconv2d-chw-5x5s2p2-minmax-neonfp16arith-1x8.c src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-minmax-neonfp16arith-u16.c + src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-minmax-neonfp16arith.c src/f16-f32acc-rsum/gen/f16-f32acc-rsum-neonfp16arith-u32-acc4.c + src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-neonfp16arith.c src/f16-gemm/gen/f16-gemm-1x8-minmax-neonfp16arith-ld64.c src/f16-gemm/gen/f16-gemm-1x16-minmax-neonfp16arith-ld64.c src/f16-gemm/gen/f16-gemm-6x8-minmax-neonfp16arith-ld64.c diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index 72801b82732..155c7f7f374 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -76,10 +76,12 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/f32-rdminmax/gen/f32-rdmax-2p2x-scalar-u2.c src/f32-rdminmax/gen/f32-rdmin-2p2x-scalar-u2.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-scalar.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-scalar.c src/f32-rminmax/gen/f32-rmax-scalar-u4-acc4.c src/f32-rminmax/gen/f32-rmin-scalar-u4-acc4.c src/f32-rminmax/gen/f32-rminmax-scalar-u4-acc4.c src/f32-rsum/gen/f32-rsum-scalar-u4-acc4.c + src/f32-rsum2/gen/f32-rsum2-scalar-u1.c src/f32-spmm/gen/f32-spmm-8x1-minmax-scalar.c src/f32-spmm/gen/f32-spmm-8x2-minmax-scalar.c src/f32-spmm/gen/f32-spmm-8x4-minmax-scalar.c diff --git a/cmake/gen/sse2_microkernels.cmake b/cmake/gen/sse2_microkernels.cmake index 065763ed14a..f7f32d282bc 100644 --- a/cmake/gen/sse2_microkernels.cmake +++ 
b/cmake/gen/sse2_microkernels.cmake @@ -23,7 +23,9 @@ SET(PROD_SSE2_MICROKERNEL_SRCS src/f32-rdminmax/gen/f32-rdmax-2p2x-sse2-u32.c src/f32-rdminmax/gen/f32-rdmin-2p2x-sse2-u32.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse2-u16.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-sse2.c src/f32-rsum/gen/f32-rsum-sse2-u16-acc4.c + src/f32-rsum2/gen/f32-rsum2-sse2-u4.c src/f32-vapproxgelu/gen/f32-vapproxgelu-sse2-rational-12-10-div.c src/f32-vbinary/gen/f32-vprelu-sse2-u8.c src/f32-vbinary/gen/f32-vpreluc-sse2-u8.c diff --git a/cmake/gen/wasmsimd_microkernels.cmake b/cmake/gen/wasmsimd_microkernels.cmake index 887e7558c2b..a3cafd40d90 100644 --- a/cmake/gen/wasmsimd_microkernels.cmake +++ b/cmake/gen/wasmsimd_microkernels.cmake @@ -94,10 +94,12 @@ SET(PROD_WASMSIMD_MICROKERNEL_SRCS src/f32-rdminmax/gen/f32-rdmax-2p2x-wasmsimd-u32.c src/f32-rdminmax/gen/f32-rdmin-2p2x-wasmsimd-u32.c src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-u16.c + src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-wasmsimd.c src/f32-rminmax/gen/f32-rmax-wasmsimd-pminmax-u16-acc4.c src/f32-rminmax/gen/f32-rmin-wasmsimd-minmax-u16-acc4.c src/f32-rminmax/gen/f32-rminmax-wasmsimd-minmax-u16-acc4.c src/f32-rsum/gen/f32-rsum-wasmsimd-u16-acc4.c + src/f32-rsum2/gen/f32-rsum2-wasmsimd-u4.c src/f32-spmm/gen/f32-spmm-32x1-minmax-wasmsimd-arm.c src/f32-spmm/gen/f32-spmm-32x1-minmax-wasmsimd-x86.c src/f32-vapproxgelu/gen/f32-vapproxgelu-wasmsimd-rational-12-10-div.c diff --git a/gen/avx512f_microkernels.bzl b/gen/avx512f_microkernels.bzl index 22e185fcdab..19f82ff9244 100644 --- a/gen/avx512f_microkernels.bzl +++ b/gen/avx512f_microkernels.bzl @@ -23,10 +23,12 @@ PROD_AVX512F_MICROKERNEL_SRCS = [ "src/f32-rdminmax/gen/f32-rdmax-2p2x-avx512f-u32.c", "src/f32-rdminmax/gen/f32-rdmin-2p2x-avx512f-u32.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-u64.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c", "src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c", "src/f32-rminmax/gen/f32-rmin-avx512f-u64-acc4.c", "src/f32-rminmax/gen/f32-rminmax-avx512f-u64-acc4.c", "src/f32-rsum/gen/f32-rsum-avx512f-u32-acc2.c", + "src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c", "src/f32-vapproxgelu/gen/f32-vapproxgelu-avx512f-rational-12-10-div.c", "src/f32-vbinary/gen/f32-vadd-avx512f-u32.c", "src/f32-vbinary/gen/f32-vaddc-avx512f-u32.c", diff --git a/gen/avx512skx_microkernels.bzl b/gen/avx512skx_microkernels.bzl index c14d7ae9e96..4b0fa242cb8 100644 --- a/gen/avx512skx_microkernels.bzl +++ b/gen/avx512skx_microkernels.bzl @@ -8,7 +8,9 @@ Auto-generated file. Do not edit! 
PROD_AVX512SKX_MICROKERNEL_SRCS = [ "src/f16-f32-vcvt/gen/f16-f32-vcvt-avx512skx-u16.c", "src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u64.c", + "src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c", "src/f16-f32acc-rsum/gen/f16-f32acc-rsum-avx512skx-u32-acc2.c", + "src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c", "src/f16-rminmax/gen/f16-rmax-avx512skx-u64-acc4.c", "src/f16-rminmax/gen/f16-rmin-avx512skx-u64-acc4.c", "src/f16-rminmax/gen/f16-rminmax-avx512skx-u64-acc4.c", diff --git a/gen/avx_microkernels.bzl b/gen/avx_microkernels.bzl index 399737f78e3..8f991ab5884 100644 --- a/gen/avx_microkernels.bzl +++ b/gen/avx_microkernels.bzl @@ -30,10 +30,12 @@ PROD_AVX_MICROKERNEL_SRCS = [ "src/f32-rdminmax/gen/f32-rdmax-2p2x-avx-u32.c", "src/f32-rdminmax/gen/f32-rdmin-2p2x-avx-u32.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-u32.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c", "src/f32-rminmax/gen/f32-rmax-avx-u32-acc4.c", "src/f32-rminmax/gen/f32-rmin-avx-u32-acc4.c", "src/f32-rminmax/gen/f32-rminmax-avx-u32-acc4.c", "src/f32-rsum/gen/f32-rsum-avx-u32-acc4.c", + "src/f32-rsum2/gen/f32-rsum2-avx-u8.c", "src/f32-vapproxgelu/gen/f32-vapproxgelu-avx-rational-12-10-div.c", "src/f32-vbinary/gen/f32-vadd-avx-u16.c", "src/f32-vbinary/gen/f32-vaddc-avx-u16.c", diff --git a/gen/f16c_microkernels.bzl b/gen/f16c_microkernels.bzl index 25468c21b82..c6996c35b16 100644 --- a/gen/f16c_microkernels.bzl +++ b/gen/f16c_microkernels.bzl @@ -9,7 +9,9 @@ PROD_F16C_MICROKERNEL_SRCS = [ "src/f16-avgpool/gen/f16-avgpool-9p-minmax-f16c.c", "src/f16-f32-vcvt/gen/f16-f32-vcvt-f16c-u16.c", "src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-f16c-u32.c", + "src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-f16c.c", "src/f16-f32acc-rsum/gen/f16-f32acc-rsum-f16c-u32-acc4.c", + "src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-f16c.c", "src/f16-rminmax/f16-rmax-f16c-u32.c", "src/f16-vbinary/gen/f16-vadd-f16c-u16.c", "src/f16-vbinary/gen/f16-vaddc-f16c-u16.c", diff --git a/gen/hvx_microkernels.bzl b/gen/hvx_microkernels.bzl index 5afce5f0bb6..37c65d5724b 100644 --- a/gen/hvx_microkernels.bzl +++ b/gen/hvx_microkernels.bzl @@ -160,6 +160,7 @@ NON_PROD_HVX_MICROKERNEL_SRCS = [ "src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-hvx-rr2-p5-u128-acc4.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-hvx-u32.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-hvx-u64.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-hvx.c", "src/f32-rminmax/gen/f32-rmax-hvx-u32.c", "src/f32-rminmax/gen/f32-rmax-hvx-u96-acc3.c", "src/f32-rminmax/gen/f32-rmax-hvx-u128-acc2.c", diff --git a/gen/neon_microkernels.bzl b/gen/neon_microkernels.bzl index 7b10158019e..266ecab9391 100644 --- a/gen/neon_microkernels.bzl +++ b/gen/neon_microkernels.bzl @@ -37,10 +37,12 @@ PROD_NEON_MICROKERNEL_SRCS = [ "src/f32-rdminmax/gen/f32-rdmax-2p2x-neon-u32.c", "src/f32-rdminmax/gen/f32-rdmin-2p2x-neon-u32.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-neon-u16.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-neon.c", "src/f32-rminmax/gen/f32-rmax-neon-u16-acc4.c", "src/f32-rminmax/gen/f32-rmin-neon-u16-acc4.c", "src/f32-rminmax/gen/f32-rminmax-neon-u16-acc4.c", "src/f32-rsum/gen/f32-rsum-neon-u16-acc4.c", + "src/f32-rsum2/gen/f32-rsum2-neon.c", "src/f32-spmm/gen/f32-spmm-32x1-minmax-neon.c", "src/f32-vapproxgelu/gen/f32-vapproxgelu-neon-rational-12-10-div.c", "src/f32-vbinary/gen/f32-vadd-neon-u8.c", diff --git a/gen/neonfp16arith_microkernels.bzl b/gen/neonfp16arith_microkernels.bzl index ac014bd69e7..500208b328a 100644 --- 
a/gen/neonfp16arith_microkernels.bzl +++ b/gen/neonfp16arith_microkernels.bzl @@ -18,7 +18,9 @@ PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ "src/f16-dwconv2d-chw/gen/f16-dwconv2d-chw-5x5p2-minmax-neonfp16arith-1x8.c", "src/f16-dwconv2d-chw/gen/f16-dwconv2d-chw-5x5s2p2-minmax-neonfp16arith-1x8.c", "src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-minmax-neonfp16arith-u16.c", + "src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-minmax-neonfp16arith.c", "src/f16-f32acc-rsum/gen/f16-f32acc-rsum-neonfp16arith-u32-acc4.c", + "src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-neonfp16arith.c", "src/f16-gemm/gen/f16-gemm-1x8-minmax-neonfp16arith-ld64.c", "src/f16-gemm/gen/f16-gemm-1x16-minmax-neonfp16arith-ld64.c", "src/f16-gemm/gen/f16-gemm-6x8-minmax-neonfp16arith-ld64.c", diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index c538ae1e86c..66e5f2b0450 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -72,10 +72,12 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/f32-rdminmax/gen/f32-rdmax-2p2x-scalar-u2.c", "src/f32-rdminmax/gen/f32-rdmin-2p2x-scalar-u2.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-scalar.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-scalar.c", "src/f32-rminmax/gen/f32-rmax-scalar-u4-acc4.c", "src/f32-rminmax/gen/f32-rmin-scalar-u4-acc4.c", "src/f32-rminmax/gen/f32-rminmax-scalar-u4-acc4.c", "src/f32-rsum/gen/f32-rsum-scalar-u4-acc4.c", + "src/f32-rsum2/gen/f32-rsum2-scalar-u1.c", "src/f32-spmm/gen/f32-spmm-8x1-minmax-scalar.c", "src/f32-spmm/gen/f32-spmm-8x2-minmax-scalar.c", "src/f32-spmm/gen/f32-spmm-8x4-minmax-scalar.c", diff --git a/gen/sse2_microkernels.bzl b/gen/sse2_microkernels.bzl index b3fa6513154..b4fa475f050 100644 --- a/gen/sse2_microkernels.bzl +++ b/gen/sse2_microkernels.bzl @@ -19,7 +19,9 @@ PROD_SSE2_MICROKERNEL_SRCS = [ "src/f32-rdminmax/gen/f32-rdmax-2p2x-sse2-u32.c", "src/f32-rdminmax/gen/f32-rdmin-2p2x-sse2-u32.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-sse2-u16.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-sse2.c", "src/f32-rsum/gen/f32-rsum-sse2-u16-acc4.c", + "src/f32-rsum2/gen/f32-rsum2-sse2-u4.c", "src/f32-vapproxgelu/gen/f32-vapproxgelu-sse2-rational-12-10-div.c", "src/f32-vbinary/gen/f32-vprelu-sse2-u8.c", "src/f32-vbinary/gen/f32-vpreluc-sse2-u8.c", diff --git a/gen/wasmsimd_microkernels.bzl b/gen/wasmsimd_microkernels.bzl index 0875a41e866..841efca55be 100644 --- a/gen/wasmsimd_microkernels.bzl +++ b/gen/wasmsimd_microkernels.bzl @@ -90,10 +90,12 @@ PROD_WASMSIMD_MICROKERNEL_SRCS = [ "src/f32-rdminmax/gen/f32-rdmax-2p2x-wasmsimd-u32.c", "src/f32-rdminmax/gen/f32-rdmin-2p2x-wasmsimd-u32.c", "src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-wasmsimd-u16.c", + "src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-wasmsimd.c", "src/f32-rminmax/gen/f32-rmax-wasmsimd-pminmax-u16-acc4.c", "src/f32-rminmax/gen/f32-rmin-wasmsimd-minmax-u16-acc4.c", "src/f32-rminmax/gen/f32-rminmax-wasmsimd-minmax-u16-acc4.c", "src/f32-rsum/gen/f32-rsum-wasmsimd-u16-acc4.c", + "src/f32-rsum2/gen/f32-rsum2-wasmsimd-u4.c", "src/f32-spmm/gen/f32-spmm-32x1-minmax-wasmsimd-arm.c", "src/f32-spmm/gen/f32-spmm-32x1-minmax-wasmsimd-x86.c", "src/f32-vapproxgelu/gen/f32-vapproxgelu-wasmsimd-rational-12-10-div.c", diff --git a/include/xnnpack.h b/include/xnnpack.h index 2a323fbb70c..de6bae60d69 100644 --- a/include/xnnpack.h +++ b/include/xnnpack.h @@ -1468,6 +1468,7 @@ XNN_DEPRECATED enum xnn_status xnn_define_static_mean( enum xnn_reduce_operator { xnn_reduce_invalid = -1, xnn_reduce_sum, + xnn_reduce_sum_squared, xnn_reduce_mean, xnn_reduce_max, xnn_reduce_min, @@ 
-2120,6 +2121,38 @@ enum xnn_status xnn_define_softmax( uint32_t output_id, uint32_t flags); +/// Type of normalize operation +enum xnn_norm_type { + xnn_norm_invalid = -1, + xnn_norm_l2 = 0, + xnn_norm_rms = 1, +}; + +/// Define a Normalization Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be +/// defined in the @a subgraph, and have at least one +/// dimension. +/// @param scale_id - Optional value ID for the scale tensor. The 1D scale +/// tensor must be defined in the @a subgraph, and have the +/// same length as the input tensor's last dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be +/// defined in the @a subgraph, and its shape must match the +/// shape of the input tensor. +/// @param epsilon - The value by which to offset the mean of the squares +/// before inversion (usually ~1e-6). +/// @param flags - binary features of the RMSNorm Node. No supported flags +/// are currently defined. +enum xnn_status xnn_define_normalize( + xnn_subgraph_t subgraph, + enum xnn_norm_type norm_type, + uint32_t input_id, + uint32_t scale_id, + uint32_t output_id, + float epsilon, + uint32_t flags); + /// Define a Space To Depth 2D Node and add it to a Subgraph. /// /// The Space To Depth 2D Node rearranges blocks of spatial data into blocks (a reverse transform to Depth To Space 2D). @@ -4455,6 +4488,46 @@ enum xnn_status xnn_setup_resize_bilinear2d_nhwc( const void* input, void* output); +enum xnn_status xnn_create_normalize_nc_f32( + enum xnn_norm_type norm_type, + float epsilon, + uint32_t flags, + xnn_operator_t* normalize_op_out); + +enum xnn_status xnn_reshape_normalize_nc_f32( + xnn_operator_t normalize_op, + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_normalize_nc_f32( + xnn_operator_t normalize_op, + const float* input, + const float* scale, + float* output); + +enum xnn_status xnn_create_normalize_nc_f16( + enum xnn_norm_type norm_type, + float epsilon, + uint32_t flags, + xnn_operator_t* normalize_op_out); + +enum xnn_status xnn_reshape_normalize_nc_f16( + xnn_operator_t normalize_op, + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_normalize_nc_f16( + xnn_operator_t normalize_op, + const void* input, + const void* scale, + void* output); + enum xnn_status xnn_create_rope_nthc_f16( uint32_t flags, xnn_operator_t* rope_op_out); diff --git a/scripts/generate-f16-f32acc-rdsum.sh b/scripts/generate-f16-f32acc-rdsum.sh index 304c785800e..79e69671755 100755 --- a/scripts/generate-f16-f32acc-rdsum.sh +++ b/scripts/generate-f16-f32acc-rdsum.sh @@ -20,9 +20,5 @@ tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=16 -D ACCUMULA tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=32 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u32.c & tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=64 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u64.c & tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=128 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u128.c & -tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=16 -D ACCUMULATORS=7 -o 
src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u16.c & -tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=32 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u32.c & -tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=64 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u64.c & -tools/xngen src/f16-f32acc-rdsum/avx512skx.c.in -D CHANNELS_BATCH=128 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u128.c & wait diff --git a/scripts/generate-f16-f32acc-rdsum2.sh b/scripts/generate-f16-f32acc-rdsum2.sh new file mode 100755 index 00000000000..66b874bab68 --- /dev/null +++ b/scripts/generate-f16-f32acc-rdsum2.sh @@ -0,0 +1,16 @@ +#!/bin/sh +# Copyright 2025 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#################################### NEON ##################################### +tools/xngen src/f16-f32acc-rdsum2/neon.c.in -D CHANNELS_BATCHES=16,32,64 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-minmax-neonfp16arith.c & + +################################## x86 AVX #################################### +tools/xngen src/f16-f32acc-rdsum2/avx.c.in -D CHANNELS_BATCHES=16,32,64,128 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-f16c.c & + +################################## x86 AVX512 ################################# +tools/xngen src/f16-f32acc-rdsum2/avx512skx.c.in -D CHANNELS_BATCHES=16,32,64,128 -D ACCUMULATORS=7 -o src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c & + +wait diff --git a/scripts/generate-f16-f32acc-rsum2.sh b/scripts/generate-f16-f32acc-rsum2.sh new file mode 100755 index 00000000000..3b249d56439 --- /dev/null +++ b/scripts/generate-f16-f32acc-rsum2.sh @@ -0,0 +1,16 @@ +#!/bin/sh +# Copyright 2025 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +################################# ARM NEONFP16 ################################ +tools/xngen src/f16-f32acc-rsum2/neonfp16arith.c.in -D BATCH_TILES=8,16,24,32 -o src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-neonfp16arith.c & + +################################### x86 F16C ################################## +tools/xngen src/f16-f32acc-rsum2/f16c.c.in -D BATCH_TILES=8,16,24,32 -o src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-f16c.c & + +################################## x86 AVX512 ################################# +tools/xngen src/f16-f32acc-rsum2/avx512skx.c.in -D BATCH_TILES=16,32,48,64,128 -o src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c & + +wait diff --git a/scripts/generate-f32-rdsum2.sh b/scripts/generate-f32-rdsum2.sh new file mode 100755 index 00000000000..fec0f5155ce --- /dev/null +++ b/scripts/generate-f32-rdsum2.sh @@ -0,0 +1,16 @@ +#!/bin/sh +# Copyright 2025 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
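As a usage illustration for the normalization entry points declared in include/xnnpack.h above, the sketch below defines a single RMS-norm node over a [batch, channels] fp32 tensor. It is a minimal sketch, not code from this change: the shapes are arbitrary, error handling is collapsed to early returns (leaking the subgraph on failure), and passing XNN_INVALID_VALUE_ID for the optional scale tensor is assumed to follow XNNPACK's usual convention for omitted optional inputs.

#include <stddef.h>
#include <stdint.h>
#include "xnnpack.h"

// Hypothetical example: build a subgraph containing one xnn_norm_rms node.
enum xnn_status define_rms_norm_example(xnn_subgraph_t* subgraph_out) {
  const size_t dims[2] = {4, 256};  // arbitrary [batch, channels] shape
  xnn_subgraph_t subgraph = NULL;
  enum xnn_status status =
      xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
  if (status != xnn_status_success) return status;

  uint32_t input_id = XNN_INVALID_VALUE_ID;
  status = xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, dims,
                                   /*data=*/NULL, /*external_id=*/0,
                                   XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
  if (status != xnn_status_success) return status;

  uint32_t output_id = XNN_INVALID_VALUE_ID;
  status = xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, dims,
                                   /*data=*/NULL, /*external_id=*/1,
                                   XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
  if (status != xnn_status_success) return status;

  // No learned scale: pass XNN_INVALID_VALUE_ID for the optional scale_id
  // (assumed convention for omitted optional tensors).
  status = xnn_define_normalize(subgraph, xnn_norm_rms, input_id,
                                /*scale_id=*/XNN_INVALID_VALUE_ID, output_id,
                                /*epsilon=*/1e-6f, /*flags=*/0);
  if (status != xnn_status_success) return status;

  *subgraph_out = subgraph;
  return xnn_status_success;
}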
+ +################################# SIMD Wrappers ################################ +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=scalar -D SIMD_SIZE=1 -D CHANNELS=4 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-scalar.c & +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=neon -D SIMD_SIZE=4 -D CHANNELS=16,32,64 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-neon.c & +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=sse2 -D SIMD_SIZE=4 -D CHANNELS=16,32,64 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-sse2.c & +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=avx -D SIMD_SIZE=8 -D CHANNELS=16,32,64 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c & +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=avx512f -D SIMD_SIZE=16 -D CHANNELS=16,32,64,128 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c & +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=hvx -D SIMD_SIZE=32 -D CHANNELS=32,64,128 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-hvx.c & +tools/xngen src/f32-rdsum2/simd.c.in -D ARCH=wasmsimd -D SIMD_SIZE=4 -D CHANNELS=16,32,64 -D ACCUMULATORS=7 -o src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-wasmsimd.c & + +wait diff --git a/scripts/generate-f32-rsum2.sh b/scripts/generate-f32-rsum2.sh new file mode 100755 index 00000000000..732a94dd587 --- /dev/null +++ b/scripts/generate-f32-rsum2.sh @@ -0,0 +1,15 @@ +#!/bin/sh +# Copyright 2025 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +################################# SIMD Wrappers ################################ +tools/xngen src/f32-rsum2/simd.c.in -D ARCH=neon -D BATCH_TILES=4,8,12,16 -o src/f32-rsum2/gen/f32-rsum2-neon.c & +tools/xngen src/f32-rsum2/simd.c.in -D ARCH=sse2 -D BATCH_TILES=4,8,12,16 -o src/f32-rsum2/gen/f32-rsum2-sse2-u4.c & +tools/xngen src/f32-rsum2/simd.c.in -D ARCH=avx -D BATCH_TILES=8,16,24,32 -o src/f32-rsum2/gen/f32-rsum2-avx-u8.c & +tools/xngen src/f32-rsum2/simd.c.in -D ARCH=avx512f -D BATCH_TILES=16,32,48,64 -o src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c & +tools/xngen src/f32-rsum2/simd.c.in -D ARCH=wasmsimd -D BATCH_TILES=4,8,12,16 -o src/f32-rsum2/gen/f32-rsum2-wasmsimd-u4.c & +tools/xngen src/f32-rsum2/simd.c.in -D ARCH=scalar -D BATCH_TILES=1,2,3,4 -o src/f32-rsum2/gen/f32-rsum2-scalar-u1.c & + +wait diff --git a/src/configs/reduce-config.c b/src/configs/reduce-config.c index bcd1b1b8b6f..c6bccfc67cf 100644 --- a/src/configs/reduce-config.c +++ b/src/configs/reduce-config.c @@ -20,6 +20,7 @@ #include "src/xnnpack/reduce.h" static struct xnn_reduce_config f16_f32acc_rsum_config = {0}; +static struct xnn_reduce_config f16_f32acc_rsum2_config = {0}; static struct xnn_reduce_config f16_rmax_config = {0}; static struct xnn_reduce_config f16_rminmax_config = {0}; static struct xnn_reduce_config f16_rmin_config = {0}; @@ -27,6 +28,7 @@ static struct xnn_reduce_config f32_rmax_config = {0}; static struct xnn_reduce_config f32_rminmax_config = {0}; static struct xnn_reduce_config f32_rmin_config = {0}; static struct xnn_reduce_config f32_rsum_config = {0}; +static struct xnn_reduce_config f32_rsum2_config = {0}; static struct xnn_reduce_config s8_rmax_config = {0}; static struct xnn_reduce_config s8_rminmax_config = {0}; static struct xnn_reduce_config s8_rmin_config = {0}; @@ -37,6 +39,7 @@ static struct xnn_reduce_config u8_rmin_config = {0}; static struct xnn_reduce_config qu8_rsum_config = {0}; XNN_INIT_ONCE_GUARD(f16_f32acc_rsum); 
+XNN_INIT_ONCE_GUARD(f16_f32acc_rsum2); XNN_INIT_ONCE_GUARD(f16_rmax); XNN_INIT_ONCE_GUARD(f16_rminmax); XNN_INIT_ONCE_GUARD(f16_rmin); @@ -44,6 +47,7 @@ XNN_INIT_ONCE_GUARD(f32_rmax); XNN_INIT_ONCE_GUARD(f32_rminmax); XNN_INIT_ONCE_GUARD(f32_rmin); XNN_INIT_ONCE_GUARD(f32_rsum); +XNN_INIT_ONCE_GUARD(f32_rsum2); XNN_INIT_ONCE_GUARD(s8_rmax); XNN_INIT_ONCE_GUARD(s8_rminmax); XNN_INIT_ONCE_GUARD(s8_rmin); @@ -451,6 +455,43 @@ static void init_f16_f32acc_rsum_config(void) { f16_f32acc_rsum_config.update = xnn_update_f32_reduce_scalar_params; } +static void init_f16_f32acc_rsum2_config(void) { + #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && XNN_ENABLE_ARM_FP16_VECTOR + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if ((hardware_config->arch_flags & xnn_arch_arm_neon_fp16_arith)) { + f16_f32acc_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc4); + f16_f32acc_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u16); + } + #elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + #if XNN_ENABLE_AVX512SKX + if ((hardware_config->arch_flags & xnn_arch_x86_avx512skx)) { + // We use a kernel with the same unroll factor as avx, because that + // produces numerically consistent results at negligible performance + // cost. + f16_f32acc_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f16_f32acc_rsum2_ukernel__avx512skx_u32_acc2); + } else + #endif + if ((hardware_config->arch_flags & xnn_arch_x86_f16c)) { + f16_f32acc_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f16_f32acc_rsum2_ukernel__f16c_u24_acc3); + } + #if XNN_ENABLE_AVX512SKX + if ((hardware_config->arch_flags & xnn_arch_x86_avx512skx)) { + f16_f32acc_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u64); + } else + #endif + if ((hardware_config->arch_flags & xnn_arch_x86_f16c)) { + f16_f32acc_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u32); + } + #endif + + f16_f32acc_rsum2_config.identity_value = 0; + f16_f32acc_rsum2_config.init.reduce = NULL; + f16_f32acc_rsum2_config.update = xnn_update_f32_reduce_scalar_params; +} + static void init_f16_rmax_config(void) { #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -765,6 +806,61 @@ static void init_f32_rsum_config(void) { f32_rsum_config.update = xnn_update_f32_reduce_scalar_params; } +static void init_f32_rsum2_config(void) { + #if XNN_ARCH_ARM + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if ((hardware_config->arch_flags & xnn_arch_arm_neon)) { + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__neon_u16_acc4); + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__neon_u16); + } else { + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__scalar_u4_acc4); + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__scalar_u4); + } + #elif XNN_ARCH_ARM64 + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__neon_u16_acc4); + f32_rsum2_config.rd_ukernel2 = 
XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__neon_u16); + #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + #if XNN_ENABLE_AVX512F + if ((hardware_config->arch_flags & xnn_arch_x86_avx512f)) { + // We use a kernel with the same unroll factor as avx, because that + // produces numerically consistent results at negligible performance + // cost. + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__avx512f_u64_acc4); + } else + #endif + if ((hardware_config->arch_flags & xnn_arch_x86_avx)) { + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__avx_u32_acc4); + } else { + // A hypothetical u32_acc8 kernel would produce results numerically + // consistent with avx and avx512f. + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__sse2_u16_acc4); + } + #if XNN_ENABLE_AVX512F + if ((hardware_config->arch_flags & xnn_arch_x86_avx512f)) { + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__avx512f_u64); + } else + #endif + if ((hardware_config->arch_flags & xnn_arch_x86_avx)) { + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__avx_u32); + } else { + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__sse2_u16); + } + #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__wasmsimd_u16_acc4); + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u16); + #else + f32_rsum2_config.ukernel = XNN_INIT_REDUCE_UKERNEL(xnn_f32_rsum2_ukernel__scalar_u4_acc4); + f32_rsum2_config.rd_ukernel2 = XNN_INIT_REDUCE_DISCONTIGUOUS_UKERNEL2(xnn_f32_rdsum2_ukernel_7p7x__scalar_u4); + #endif + + f32_rsum2_config.identity_value = 0; + f32_rsum2_config.init.reduce = NULL; + f32_rsum2_config.update = xnn_update_f32_reduce_scalar_params; +} + const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL || !xnn_is_f16_compatible_config(hardware_config)) { @@ -774,6 +870,15 @@ const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum_config() { return &f16_f32acc_rsum_config; } +const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum2_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL || !xnn_is_f16_compatible_config(hardware_config)) { + return NULL; + } + XNN_INIT_ONCE(f16_f32acc_rsum2); + return &f16_f32acc_rsum2_config; +} + const struct xnn_reduce_config* xnn_init_f16_rmax_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL) { @@ -837,6 +942,15 @@ const struct xnn_reduce_config* xnn_init_f32_rsum_config() { return &f32_rsum_config; } +const struct xnn_reduce_config* xnn_init_f32_rsum2_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL) { + return NULL; + } + XNN_INIT_ONCE(f32_rsum2); + return &f32_rsum2_config; +} + const struct xnn_reduce_config* xnn_init_s8_rmax_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL) { diff --git a/src/f16-f32acc-rdsum2/avx.c.in 
b/src/f16-f32acc-rdsum2/avx.c.in new file mode 100644 index 00000000000..71410896e5d --- /dev/null +++ b/src/f16-f32acc-rdsum2/avx.c.in @@ -0,0 +1,175 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include + +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/unaligned.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/math.h" + +$CHANNELS_BATCHES = tuple(int(cb) for cb in CHANNELS_BATCHES.split(",")) +$for CHANNELS_BATCH in CHANNELS_BATCHES: + $UNROLL = CHANNELS_BATCH >> 3 + + void xnn_f16_f32acc_rdsum2_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__f16c_u${CHANNELS_BATCH}( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) + { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = ${ACCUMULATORS} * input_stride1; + for (; channels >= ${CHANNELS_BATCH}; channels -= ${CHANNELS_BATCH}) { + const uint16_t* i0 = (const uint16_t*) input_row; + $for ACC in range(1, ACCUMULATORS): + const uint16_t* i${ACC} = (const uint16_t*) ((uintptr_t) input_row + ${ACC} * input_stride1); + + $for i in range(UNROLL): + __m256 vacc${i} = _mm256_setzero_ps(); + + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for ACC in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${ACC+1}) { + i${ACC} = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= ${ACC+1}) { + i${ACC+1} = (const uint16_t*) zero; + } + $for c in range(UNROLL): + __m256 vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i${j}[${c*8}]))); + $for c in range(UNROLL): + vin${c} = _mm256_mul_ps(vin${c}, vin${c}); + $for c in range(UNROLL): + vacc${c} = _mm256_add_ps(vin${c}, vacc${c}); + $for ACC in range(0, ACCUMULATORS): + i${ACC} = (const uint16_t*) ((uintptr_t) i${ACC} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = _mm256_mul_ps(vacc${i}, vscale); + + const float* o = output; + $for i in range(0, UNROLL): + __m256 vo${i} = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + $for i in range(0, UNROLL): + vacc${i} = _mm256_add_ps(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + _mm256_storeu_ps(output, vacc${i}); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + ${CHANNELS_BATCH} * sizeof(xnn_float16)); + } + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + $for ACC in range(1, ACCUMULATORS): + const uint16_t* i${ACC} = (const uint16_t*) ((uintptr_t) input_row + ${ACC} * input_stride1); + __m256 vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = _mm256_setzero_ps(); + + const 
size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for ACC in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${ACC+1}) { + i${ACC} = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= ${ACC+1}) { + i${ACC+1} = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + $for c in range(ACCUMULATORS): + __m256 vin${c} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i${c}[i*8])); + $for c in range(ACCUMULATORS): + vin${c} = _mm256_mul_ps(vin${c}, vin${c}); + $for c in range(ACCUMULATORS): + vacc[i] = _mm256_add_ps(vin${c}, vacc[i]); + } + + if (remainder) { + $for c in range(ACCUMULATORS): + __m256 vin${c} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i${c}[num_full_chunks*8])); + $for c in range(ACCUMULATORS): + vin${c} = _mm256_mul_ps(vin${c}, vin${c}); + $for c in range(ACCUMULATORS): + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin${c}); + } + $for ACC in range(ACCUMULATORS): + i${ACC} = (const uint16_t*) ((uintptr_t) i${ACC} + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[${UNROLL}]; + const float* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm256_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + } + if (remainder) { + __m256 vout = vacc[num_full_chunks]; + __m128 vout_low = _mm256_castps256_ps128(vout); + if (channels & 4) { + __m128 vo = _mm_loadu_ps(output); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_ps(output, vo); + vout_low = _mm256_castps256_ps128(_mm256_permute2f128_ps(vout, vout, 1)); + output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + if (channels & 2) { + __m128 vo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i*) output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storel_pi((__m64*) output, vo); + vout_low = _mm_movehl_ps(vout_low, vout_low); + output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + } + if (channels & 1) { + __m128 vo = _mm_castsi128_ps(_mm_cvtsi32_si128(unaligned_load_s32(output))); + vo = _mm_add_ps(vout_low, vo); + _mm_store_ss(output, vo); + } + } + } + } + } + } diff --git a/src/f16-f32acc-rdsum2/avx512skx.c.in b/src/f16-f32acc-rdsum2/avx512skx.c.in new file mode 100644 index 00000000000..35f2fe2a381 --- /dev/null +++ b/src/f16-f32acc-rdsum2/avx512skx.c.in @@ -0,0 +1,158 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include + +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/intrinsics-polyfill.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + +$CHANNELS_BATCHES = tuple(int(cb) for cb in CHANNELS_BATCHES.split(",")) +$for CHANNELS_BATCH in CHANNELS_BATCHES: + $UNROLL = CHANNELS_BATCH >> 4 + + void xnn_f16_f32acc_rdsum2_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_u${CHANNELS_BATCH}( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) + { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = ${ACCUMULATORS} * input_stride1; + for (; channels >= ${CHANNELS_BATCH}; channels -= ${CHANNELS_BATCH}) { + const uint16_t* i0 = (const uint16_t*) input_row; + $for ACC in range(1, ACCUMULATORS): + const uint16_t* i${ACC} = (const uint16_t*) ((uintptr_t) input_row + ${ACC} * input_stride1); + + $for i in range(UNROLL): + __m512 vacc${i} = _mm512_setzero_ps(); + + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for ACC in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${ACC+1}) { + i${ACC} = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= ${ACC+1}) { + i${ACC+1} = (const uint16_t*) zero; + } + $for c in range(UNROLL): + __m512 vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i${j}[${c*16}]))); + $for c in range(UNROLL): + vacc${c} = _mm512_fmadd_ps(vin${c}, vin${c}, vacc${c}); + $for ACC in range(0, ACCUMULATORS): + i${ACC} = (const uint16_t*) ((uintptr_t) i${ACC} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = _mm512_mul_ps(vacc${i}, vscale); + + $for i in range(0, UNROLL): + __m512 vo${i} = _mm512_loadu_ps(output + ${i} * 16); + $for i in range(0, UNROLL): + vacc${i} = _mm512_add_ps(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + _mm512_storeu_ps(output, vacc${i}); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + ${CHANNELS_BATCH} * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + $for i in range(1, ACCUMULATORS): + const uint16_t* i${i} = (const uint16_t*) ((uintptr_t) input_row + ${i} * input_stride1); + __m512 vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + // AVX512 has 16 float lanes. + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + // 0xF masks the remainder. 
+ const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask; + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for ACC in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${ACC+1}) { + i${ACC} = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= ${ACC+1}) { + i${ACC+1} = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + $for c in range(ACCUMULATORS): + __m512 vin${c} = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i${c}[i*16])); + $for c in range(ACCUMULATORS): + vacc[i] = _mm512_fmadd_ps(vin${c}, vin${c}, vacc[i]); + } + + if (remainder) { + $for c in range(ACCUMULATORS): + __m512 vin${c} = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i${c}[num_full_chunks*16])); + $for c in range(ACCUMULATORS): + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin${c}, vin${c}, vacc[num_full_chunks]); + } + $for ACC in range(ACCUMULATORS): + i${ACC} = (const uint16_t*) ((uintptr_t) i${ACC} + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[${UNROLL}]; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm512_loadu_ps(output + i * 16); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm512_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + } + if (remainder) { + __m512 vout = vacc[num_full_chunks]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } + } + } + } diff --git a/src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc b/src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc new file mode 100644 index 00000000000..8ee5ead3c67 --- /dev/null +++ b/src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc @@ -0,0 +1,26 @@ +// clang-format off +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#if XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u16, 7, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u32, 7, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u64, 7, 64, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +#endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u16, 7, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u32, 7, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u64, 7, 64, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u128, 7, 128, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + +#if XNN_ENABLE_AVX512SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u16, 7, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u32, 7, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u64, 7, 64, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u128, 7, 128, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +#endif // XNN_ENABLE_AVX512SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) + diff --git a/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c b/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c new file mode 100644 index 00000000000..3ce82c94f23 --- /dev/null +++ b/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c @@ -0,0 +1,1053 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f16-f32acc-rdsum2/avx512skx.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include + +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/intrinsics-polyfill.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u16( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m512 vacc0 = _mm512_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m512 vin0; + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[0]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + + __m512 vo0 = 
_mm512_loadu_ps(output + 0 * 16); + vacc0 = _mm512_add_ps(vo0, vacc0); + _mm512_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 16 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m512 vacc[1]; + vacc[0] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + // AVX512 has 16 float lanes. + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + // 0xF masks the remainder. + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask; + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i0[i*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i1[i*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i2[i*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i3[i*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i4[i*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i5[i*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i6[i*16])); + vacc[i] = _mm512_fmadd_ps(vin0, vin0, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin1, vin1, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin2, vin2, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin3, vin3, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin4, vin4, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin5, vin5, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin6, vin6, vacc[i]); + } + + if (remainder) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i0[num_full_chunks*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i1[num_full_chunks*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i2[num_full_chunks*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i3[num_full_chunks*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i4[num_full_chunks*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i5[num_full_chunks*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i6[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin0, vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = 
_mm512_maskz_fmadd_ps(vmask, vin1, vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin2, vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin3, vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin4, vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin5, vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin6, vin6, vacc[num_full_chunks]); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[1]; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm512_loadu_ps(output + i * 16); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm512_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + } + if (remainder) { + __m512 vout = vacc[num_full_chunks]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u32( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + 
if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m512 vin0; + __m512 vin1; + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[16]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + + __m512 vo0 = _mm512_loadu_ps(output + 0 * 16); + __m512 vo1 = _mm512_loadu_ps(output + 1 * 16); + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + _mm512_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 32 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m512 vacc[2]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + // AVX512 has 16 float lanes. 
+ const size_t num_chunks = round_up_po2(channels, 16) >> 4; + // 0xF masks the remainder. + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask; + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i0[i*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i1[i*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i2[i*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i3[i*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i4[i*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i5[i*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i6[i*16])); + vacc[i] = _mm512_fmadd_ps(vin0, vin0, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin1, vin1, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin2, vin2, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin3, vin3, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin4, vin4, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin5, vin5, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin6, vin6, vacc[i]); + } + + if (remainder) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i0[num_full_chunks*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i1[num_full_chunks*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i2[num_full_chunks*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i3[num_full_chunks*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i4[num_full_chunks*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i5[num_full_chunks*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i6[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin0, vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin1, vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin2, vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin3, vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin4, vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin5, vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin6, vin6, vacc[num_full_chunks]); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], 
vscale); + } + + __m512 vo[2]; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm512_loadu_ps(output + i * 16); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm512_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + } + if (remainder) { + __m512 vout = vacc[num_full_chunks]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u64( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m512 vin0; + __m512 vin1; + __m512 vin2; + __m512 vin3; + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = 
_mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[48]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + vacc2 = _mm512_mul_ps(vacc2, vscale); + vacc3 = _mm512_mul_ps(vacc3, vscale); + + __m512 vo0 = _mm512_loadu_ps(output + 0 * 16); + __m512 vo1 = _mm512_loadu_ps(output + 1 * 16); + __m512 vo2 = _mm512_loadu_ps(output + 2 * 16); + __m512 vo3 = _mm512_loadu_ps(output + 3 * 16); + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + vacc2 = _mm512_add_ps(vo2, vacc2); + vacc3 = _mm512_add_ps(vo3, vacc3); + _mm512_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); 
+ _mm512_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc2); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc3); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 64 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m512 vacc[4]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + vacc[2] = _mm512_setzero_ps(); + vacc[3] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + // AVX512 has 16 float lanes. + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + // 0xF masks the remainder. + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask; + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i0[i*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i1[i*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i2[i*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i3[i*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i4[i*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i5[i*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i6[i*16])); + vacc[i] = _mm512_fmadd_ps(vin0, vin0, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin1, vin1, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin2, vin2, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin3, vin3, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin4, vin4, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin5, vin5, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin6, vin6, vacc[i]); + } + + if (remainder) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i0[num_full_chunks*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i1[num_full_chunks*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i2[num_full_chunks*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i3[num_full_chunks*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i4[num_full_chunks*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i5[num_full_chunks*16])); + 
__m512 vin6 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i6[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin0, vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin1, vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin2, vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin3, vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin4, vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin5, vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin6, vin6, vacc[num_full_chunks]); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[4]; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm512_loadu_ps(output + i * 16); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm512_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + } + if (remainder) { + __m512 vout = vacc[num_full_chunks]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__avx512skx_u128( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 128; channels -= 128) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + __m512 vacc4 = _mm512_setzero_ps(); + __m512 vacc5 = _mm512_setzero_ps(); + __m512 vacc6 = 
_mm512_setzero_ps(); + __m512 vacc7 = _mm512_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m512 vin0; + __m512 vin1; + __m512 vin2; + __m512 vin3; + __m512 vin4; + __m512 vin5; + __m512 vin6; + __m512 vin7; + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const 
__m256i*) (&i3[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[48]))); + vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[64]))); + vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[80]))); + vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[96]))); + vin7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[112]))); + vacc0 = _mm512_fmadd_ps(vin0, vin0, vacc0); + vacc1 = _mm512_fmadd_ps(vin1, vin1, vacc1); + vacc2 = _mm512_fmadd_ps(vin2, vin2, 
vacc2); + vacc3 = _mm512_fmadd_ps(vin3, vin3, vacc3); + vacc4 = _mm512_fmadd_ps(vin4, vin4, vacc4); + vacc5 = _mm512_fmadd_ps(vin5, vin5, vacc5); + vacc6 = _mm512_fmadd_ps(vin6, vin6, vacc6); + vacc7 = _mm512_fmadd_ps(vin7, vin7, vacc7); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + vacc2 = _mm512_mul_ps(vacc2, vscale); + vacc3 = _mm512_mul_ps(vacc3, vscale); + vacc4 = _mm512_mul_ps(vacc4, vscale); + vacc5 = _mm512_mul_ps(vacc5, vscale); + vacc6 = _mm512_mul_ps(vacc6, vscale); + vacc7 = _mm512_mul_ps(vacc7, vscale); + + __m512 vo0 = _mm512_loadu_ps(output + 0 * 16); + __m512 vo1 = _mm512_loadu_ps(output + 1 * 16); + __m512 vo2 = _mm512_loadu_ps(output + 2 * 16); + __m512 vo3 = _mm512_loadu_ps(output + 3 * 16); + __m512 vo4 = _mm512_loadu_ps(output + 4 * 16); + __m512 vo5 = _mm512_loadu_ps(output + 5 * 16); + __m512 vo6 = _mm512_loadu_ps(output + 6 * 16); + __m512 vo7 = _mm512_loadu_ps(output + 7 * 16); + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + vacc2 = _mm512_add_ps(vo2, vacc2); + vacc3 = _mm512_add_ps(vo3, vacc3); + vacc4 = _mm512_add_ps(vo4, vacc4); + vacc5 = _mm512_add_ps(vo5, vacc5); + vacc6 = _mm512_add_ps(vo6, vacc6); + vacc7 = _mm512_add_ps(vo7, vacc7); + _mm512_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc2); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc3); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc4); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc5); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc6); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + _mm512_storeu_ps(output, vacc7); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 128 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m512 vacc[8]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + vacc[2] = _mm512_setzero_ps(); + vacc[3] = _mm512_setzero_ps(); + vacc[4] = _mm512_setzero_ps(); + vacc[5] = _mm512_setzero_ps(); + vacc[6] = _mm512_setzero_ps(); + vacc[7] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + // AVX512 has 16 float lanes. 
+ const size_t num_chunks = round_up_po2(channels, 16) >> 4; + // 0xF masks the remainder. + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask; + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i0[i*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i1[i*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i2[i*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i3[i*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i4[i*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i5[i*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i6[i*16])); + vacc[i] = _mm512_fmadd_ps(vin0, vin0, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin1, vin1, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin2, vin2, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin3, vin3, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin4, vin4, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin5, vin5, vacc[i]); + vacc[i] = _mm512_fmadd_ps(vin6, vin6, vacc[i]); + } + + if (remainder) { + __m512 vin0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i0[num_full_chunks*16])); + __m512 vin1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i1[num_full_chunks*16])); + __m512 vin2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i2[num_full_chunks*16])); + __m512 vin3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i3[num_full_chunks*16])); + __m512 vin4 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i4[num_full_chunks*16])); + __m512 vin5 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i5[num_full_chunks*16])); + __m512 vin6 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i6[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin0, vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin1, vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin2, vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin3, vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin4, vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin5, vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm512_maskz_fmadd_ps(vmask, vin6, vin6, vacc[num_full_chunks]); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], 
vscale); + } + + __m512 vo[8]; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm512_loadu_ps(output + i * 16); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm512_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 16 * sizeof(float)); + } + if (remainder) { + __m512 vout = vacc[num_full_chunks]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } + } + } +} diff --git a/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-f16c.c b/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-f16c.c new file mode 100644 index 00000000000..fdaa869e6fa --- /dev/null +++ b/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-f16c.c @@ -0,0 +1,1678 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f16-f32acc-rdsum2/avx.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include + +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/unaligned.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/math.h" + + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u16( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m256 vin0; + __m256 vin1; + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[8]))); + vin0 = 
_mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[8]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[8]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[8]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[8]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[8]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[8]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo1 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + _mm256_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 16 * sizeof(xnn_float16)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 
* input_stride1); + __m256 vacc[2]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[i*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[i*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[i*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[i*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[i*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[i*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[i*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[i] = _mm256_add_ps(vin0, vacc[i]); + vacc[i] = _mm256_add_ps(vin1, vacc[i]); + vacc[i] = _mm256_add_ps(vin2, vacc[i]); + vacc[i] = _mm256_add_ps(vin3, vacc[i]); + vacc[i] = _mm256_add_ps(vin4, vacc[i]); + vacc[i] = _mm256_add_ps(vin5, vacc[i]); + vacc[i] = _mm256_add_ps(vin6, vacc[i]); + } + + if (remainder) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[num_full_chunks*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[num_full_chunks*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[num_full_chunks*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[num_full_chunks*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[num_full_chunks*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[num_full_chunks*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[num_full_chunks*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin0); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin1); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin2); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin3); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin4); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin5); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const 
uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[2]; + const float* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm256_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + } + if (remainder) { + __m256 vout = vacc[num_full_chunks]; + __m128 vout_low = _mm256_castps256_ps128(vout); + if (channels & 4) { + __m128 vo = _mm_loadu_ps(output); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_ps(output, vo); + vout_low = _mm256_castps256_ps128(_mm256_permute2f128_ps(vout, vout, 1)); + output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + if (channels & 2) { + __m128 vo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i*) output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storel_pi((__m64*) output, vo); + vout_low = _mm_movehl_ps(vout_low, vout_low); + output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + } + if (channels & 1) { + __m128 vo = _mm_castsi128_ps(_mm_cvtsi32_si128(unaligned_load_s32(output))); + vo = _mm_add_ps(vout_low, vo); + _mm_store_ss(output, vo); + } + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u32( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m256 vin0; + 
__m256 vin1; + __m256 vin2; + __m256 vin3; + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = 
_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[24]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo1 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo2 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo3 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + _mm256_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc2); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc3); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 32 * sizeof(xnn_float16)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m256 vacc[4]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 
0; i < num_full_chunks; ++i) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[i*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[i*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[i*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[i*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[i*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[i*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[i*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[i] = _mm256_add_ps(vin0, vacc[i]); + vacc[i] = _mm256_add_ps(vin1, vacc[i]); + vacc[i] = _mm256_add_ps(vin2, vacc[i]); + vacc[i] = _mm256_add_ps(vin3, vacc[i]); + vacc[i] = _mm256_add_ps(vin4, vacc[i]); + vacc[i] = _mm256_add_ps(vin5, vacc[i]); + vacc[i] = _mm256_add_ps(vin6, vacc[i]); + } + + if (remainder) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[num_full_chunks*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[num_full_chunks*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[num_full_chunks*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[num_full_chunks*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[num_full_chunks*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[num_full_chunks*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[num_full_chunks*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin0); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin1); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin2); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin3); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin4); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin5); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[4]; + const float* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm256_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + } + if (remainder) { + __m256 vout = vacc[num_full_chunks]; + __m128 vout_low = 
_mm256_castps256_ps128(vout); + if (channels & 4) { + __m128 vo = _mm_loadu_ps(output); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_ps(output, vo); + vout_low = _mm256_castps256_ps128(_mm256_permute2f128_ps(vout, vout, 1)); + output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + if (channels & 2) { + __m128 vo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i*) output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storel_pi((__m64*) output, vo); + vout_low = _mm_movehl_ps(vout_low, vout_low); + output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + } + if (channels & 1) { + __m128 vo = _mm_castsi128_ps(_mm_cvtsi32_si128(unaligned_load_s32(output))); + vo = _mm_add_ps(vout_low, vo); + _mm_store_ss(output, vo); + } + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u64( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + __m256 vacc4 = _mm256_setzero_ps(); + __m256 vacc5 = _mm256_setzero_ps(); + __m256 vacc6 = _mm256_setzero_ps(); + __m256 vacc7 = _mm256_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + __m256 vin4; + __m256 vin5; + __m256 vin6; + __m256 vin7; + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[40]))); + vin6 = 
_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[56]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[56]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[56]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[48]))); + vin7 = 
_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[56]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[56]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[56]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[56]))); + vin0 = 
_mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + vacc4 = _mm256_mul_ps(vacc4, vscale); + vacc5 = _mm256_mul_ps(vacc5, vscale); + vacc6 = _mm256_mul_ps(vacc6, vscale); + vacc7 = _mm256_mul_ps(vacc7, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo1 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo2 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo3 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo4 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo5 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo6 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo7 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + vacc4 = _mm256_add_ps(vo4, vacc4); + vacc5 = _mm256_add_ps(vo5, vacc5); + vacc6 = _mm256_add_ps(vo6, vacc6); + vacc7 = _mm256_add_ps(vo7, vacc7); + _mm256_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc2); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc3); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc4); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc5); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc6); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc7); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 64 * sizeof(xnn_float16)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) 
((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m256 vacc[8]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + vacc[4] = _mm256_setzero_ps(); + vacc[5] = _mm256_setzero_ps(); + vacc[6] = _mm256_setzero_ps(); + vacc[7] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[i*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[i*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[i*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[i*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[i*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[i*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[i*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[i] = _mm256_add_ps(vin0, vacc[i]); + vacc[i] = _mm256_add_ps(vin1, vacc[i]); + vacc[i] = _mm256_add_ps(vin2, vacc[i]); + vacc[i] = _mm256_add_ps(vin3, vacc[i]); + vacc[i] = _mm256_add_ps(vin4, vacc[i]); + vacc[i] = _mm256_add_ps(vin5, vacc[i]); + vacc[i] = _mm256_add_ps(vin6, vacc[i]); + } + + if (remainder) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[num_full_chunks*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[num_full_chunks*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[num_full_chunks*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[num_full_chunks*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[num_full_chunks*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[num_full_chunks*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[num_full_chunks*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin0); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin1); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin2); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin3); + vacc[num_full_chunks] = 
_mm256_add_ps(vacc[num_full_chunks], vin4); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin5); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[8]; + const float* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm256_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + } + if (remainder) { + __m256 vout = vacc[num_full_chunks]; + __m128 vout_low = _mm256_castps256_ps128(vout); + if (channels & 4) { + __m128 vo = _mm_loadu_ps(output); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_ps(output, vo); + vout_low = _mm256_castps256_ps128(_mm256_permute2f128_ps(vout, vout, 1)); + output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + if (channels & 2) { + __m128 vo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i*) output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storel_pi((__m64*) output, vo); + vout_low = _mm_movehl_ps(vout_low, vout_low); + output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + } + if (channels & 1) { + __m128 vo = _mm_castsi128_ps(_mm_cvtsi32_si128(unaligned_load_s32(output))); + vo = _mm_add_ps(vout_low, vo); + _mm_store_ss(output, vo); + } + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__f16c_u128( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 128; channels -= 128) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = 
_mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + __m256 vacc4 = _mm256_setzero_ps(); + __m256 vacc5 = _mm256_setzero_ps(); + __m256 vacc6 = _mm256_setzero_ps(); + __m256 vacc7 = _mm256_setzero_ps(); + __m256 vacc8 = _mm256_setzero_ps(); + __m256 vacc9 = _mm256_setzero_ps(); + __m256 vacc10 = _mm256_setzero_ps(); + __m256 vacc11 = _mm256_setzero_ps(); + __m256 vacc12 = _mm256_setzero_ps(); + __m256 vacc13 = _mm256_setzero_ps(); + __m256 vacc14 = _mm256_setzero_ps(); + __m256 vacc15 = _mm256_setzero_ps(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + __m256 vin4; + __m256 vin5; + __m256 vin6; + __m256 vin7; + __m256 vin8; + __m256 vin9; + __m256 vin10; + __m256 vin11; + __m256 vin12; + __m256 vin13; + __m256 vin14; + __m256 vin15; + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[120]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = 
_mm256_add_ps(vin15, vacc15); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[120]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = _mm256_add_ps(vin15, vacc15); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[120]))); + vin0 = 
_mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = _mm256_add_ps(vin15, vacc15); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[120]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = 
_mm256_add_ps(vin15, vacc15); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[120]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = _mm256_add_ps(vin15, vacc15); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[120]))); + vin0 = 
_mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = _mm256_add_ps(vin15, vacc15); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[24]))); + vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[32]))); + vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[40]))); + vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[48]))); + vin7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[56]))); + vin8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[64]))); + vin9 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[72]))); + vin10 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[80]))); + vin11 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[88]))); + vin12 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[96]))); + vin13 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[104]))); + vin14 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[112]))); + vin15 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[120]))); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vin7 = _mm256_mul_ps(vin7, vin7); + vin8 = _mm256_mul_ps(vin8, vin8); + vin9 = _mm256_mul_ps(vin9, vin9); + vin10 = _mm256_mul_ps(vin10, vin10); + vin11 = _mm256_mul_ps(vin11, vin11); + vin12 = _mm256_mul_ps(vin12, vin12); + vin13 = _mm256_mul_ps(vin13, vin13); + vin14 = _mm256_mul_ps(vin14, vin14); + vin15 = _mm256_mul_ps(vin15, vin15); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vacc4 = _mm256_add_ps(vin4, vacc4); + vacc5 = _mm256_add_ps(vin5, vacc5); + vacc6 = _mm256_add_ps(vin6, vacc6); + vacc7 = _mm256_add_ps(vin7, vacc7); + vacc8 = _mm256_add_ps(vin8, vacc8); + vacc9 = _mm256_add_ps(vin9, vacc9); + vacc10 = _mm256_add_ps(vin10, vacc10); + vacc11 = _mm256_add_ps(vin11, vacc11); + vacc12 = _mm256_add_ps(vin12, vacc12); + vacc13 = _mm256_add_ps(vin13, vacc13); + vacc14 = _mm256_add_ps(vin14, vacc14); + vacc15 = 
_mm256_add_ps(vin15, vacc15); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + vacc4 = _mm256_mul_ps(vacc4, vscale); + vacc5 = _mm256_mul_ps(vacc5, vscale); + vacc6 = _mm256_mul_ps(vacc6, vscale); + vacc7 = _mm256_mul_ps(vacc7, vscale); + vacc8 = _mm256_mul_ps(vacc8, vscale); + vacc9 = _mm256_mul_ps(vacc9, vscale); + vacc10 = _mm256_mul_ps(vacc10, vscale); + vacc11 = _mm256_mul_ps(vacc11, vscale); + vacc12 = _mm256_mul_ps(vacc12, vscale); + vacc13 = _mm256_mul_ps(vacc13, vscale); + vacc14 = _mm256_mul_ps(vacc14, vscale); + vacc15 = _mm256_mul_ps(vacc15, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo1 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo2 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo3 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo4 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo5 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo6 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo7 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo8 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo9 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo10 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo11 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo12 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo13 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo14 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + __m256 vo15 = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + vacc4 = _mm256_add_ps(vo4, vacc4); + vacc5 = _mm256_add_ps(vo5, vacc5); + vacc6 = _mm256_add_ps(vo6, vacc6); + vacc7 = _mm256_add_ps(vo7, vacc7); + vacc8 = _mm256_add_ps(vo8, vacc8); + vacc9 = _mm256_add_ps(vo9, vacc9); + vacc10 = _mm256_add_ps(vo10, vacc10); + vacc11 = _mm256_add_ps(vo11, vacc11); + vacc12 = _mm256_add_ps(vo12, vacc12); + vacc13 = _mm256_add_ps(vo13, vacc13); + vacc14 = _mm256_add_ps(vo14, vacc14); + vacc15 = _mm256_add_ps(vo15, vacc15); + _mm256_storeu_ps(output, vacc0); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc1); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc2); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc3); output = (void*) ((uintptr_t) output + 
8 * sizeof(float)); + _mm256_storeu_ps(output, vacc4); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc5); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc6); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc7); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc8); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc9); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc10); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc11); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc12); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc13); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc14); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + _mm256_storeu_ps(output, vacc15); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 128 * sizeof(xnn_float16)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + __m256 vacc[16]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + vacc[4] = _mm256_setzero_ps(); + vacc[5] = _mm256_setzero_ps(); + vacc[6] = _mm256_setzero_ps(); + vacc[7] = _mm256_setzero_ps(); + vacc[8] = _mm256_setzero_ps(); + vacc[9] = _mm256_setzero_ps(); + vacc[10] = _mm256_setzero_ps(); + vacc[11] = _mm256_setzero_ps(); + vacc[12] = _mm256_setzero_ps(); + vacc[13] = _mm256_setzero_ps(); + vacc[14] = _mm256_setzero_ps(); + vacc[15] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[i*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[i*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[i*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[i*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[i*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[i*8])); + __m256 vin6 = 
_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[i*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[i] = _mm256_add_ps(vin0, vacc[i]); + vacc[i] = _mm256_add_ps(vin1, vacc[i]); + vacc[i] = _mm256_add_ps(vin2, vacc[i]); + vacc[i] = _mm256_add_ps(vin3, vacc[i]); + vacc[i] = _mm256_add_ps(vin4, vacc[i]); + vacc[i] = _mm256_add_ps(vin5, vacc[i]); + vacc[i] = _mm256_add_ps(vin6, vacc[i]); + } + + if (remainder) { + __m256 vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[num_full_chunks*8])); + __m256 vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[num_full_chunks*8])); + __m256 vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[num_full_chunks*8])); + __m256 vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[num_full_chunks*8])); + __m256 vin4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[num_full_chunks*8])); + __m256 vin5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[num_full_chunks*8])); + __m256 vin6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[num_full_chunks*8])); + vin0 = _mm256_mul_ps(vin0, vin0); + vin1 = _mm256_mul_ps(vin1, vin1); + vin2 = _mm256_mul_ps(vin2, vin2); + vin3 = _mm256_mul_ps(vin3, vin3); + vin4 = _mm256_mul_ps(vin4, vin4); + vin5 = _mm256_mul_ps(vin5, vin5); + vin6 = _mm256_mul_ps(vin6, vin6); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin0); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin1); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin2); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin3); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin4); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin5); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[16]; + const float* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm256_loadu_ps(o); o = (const void*) ((uintptr_t) o + 8 * sizeof(float)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm256_storeu_ps(output, vacc[i]); output = (void*) ((uintptr_t) output + 8 * sizeof(float)); + } + if (remainder) { + __m256 vout = vacc[num_full_chunks]; + __m128 vout_low = _mm256_castps256_ps128(vout); + if (channels & 4) { + __m128 vo = _mm_loadu_ps(output); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_ps(output, vo); + vout_low = _mm256_castps256_ps128(_mm256_permute2f128_ps(vout, vout, 1)); + output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + if (channels & 2) { + __m128 vo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i*) output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storel_pi((__m64*) output, vo); + vout_low = _mm_movehl_ps(vout_low, vout_low); + output = (void*) 
((uintptr_t) output + 2 * sizeof(float)); + } + if (channels & 1) { + __m128 vo = _mm_castsi128_ps(_mm_cvtsi32_si128(unaligned_load_s32(output))); + vo = _mm_add_ps(vout_low, vo); + _mm_store_ss(output, vo); + } + } + } + } + } +} diff --git a/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-minmax-neonfp16arith.c b/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-minmax-neonfp16arith.c new file mode 100644 index 00000000000..43a6265cfa6 --- /dev/null +++ b/src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-minmax-neonfp16arith.c @@ -0,0 +1,1091 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f16-f32acc-rdsum2/neon.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u16( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float32x4_t vscale = vdupq_n_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + float32x4_t vacc0 = vdupq_n_f32(0.f); + float32x4_t vacc1 = vdupq_n_f32(0.f); + float32x4_t vacc2 = vdupq_n_f32(0.f); + float32x4_t vacc3 = vdupq_n_f32(0.f); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + float32x4_t vin0; + float32x4_t vin1; + float32x4_t vin2; + float32x4_t vin3; + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, 
vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[12]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vmulq_f32(vacc0, vscale); + vacc1 = vmulq_f32(vacc1, vscale); + vacc2 = vmulq_f32(vacc2, vscale); + vacc3 = vmulq_f32(vacc3, vscale); + + const float* o = (const float*) output; + float32x4_t vo0 = vld1q_f32(o); o += 4; + float32x4_t vo1 = vld1q_f32(o); o += 4; + float32x4_t vo2 = vld1q_f32(o); o += 4; + float32x4_t vo3 = vld1q_f32(o); o += 4; + float32x4_t v_out0 = vaddq_f32(vo0, vacc0); + float32x4_t v_out1 = vaddq_f32(vo1, vacc1); + float32x4_t v_out2 = vaddq_f32(vo2, vacc2); + float32x4_t v_out3 = vaddq_f32(vo3, vacc3); + 
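        // Note: this kernel accumulates into its destination. The scaled
+        // per-channel sums of squares computed above are combined with the
+        // values already in `output` (loaded into vo0..vo3 and added above)
+        // and written back below, so the caller is expected to initialize the
+        // output buffer before invoking the kernel.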
vst1q_f32(output, v_out0); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out1); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out2); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out3); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 16 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + float32x4_t vacc[4]; + vacc[0] = vdupq_n_f32(0.f); + vacc[1] = vdupq_n_f32(0.f); + vacc[2] = vdupq_n_f32(0.f); + vacc[3] = vdupq_n_f32(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t num_full_chunks = channels >> 2; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_chunks; ++i) { + float32x4_t vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[i*4]))); + float32x4_t vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[i*4]))); + float32x4_t vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[i*4]))); + float32x4_t vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[i*4]))); + float32x4_t vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[i*4]))); + float32x4_t vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[i*4]))); + float32x4_t vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[i*4]))); + vacc[i] = vmlaq_f32(vacc[i], vin0, vin0); + vacc[i] = vmlaq_f32(vacc[i], vin1, vin1); + vacc[i] = vmlaq_f32(vacc[i], vin2, vin2); + vacc[i] = vmlaq_f32(vacc[i], vin3, vin3); + vacc[i] = vmlaq_f32(vacc[i], vin4, vin4); + vacc[i] = vmlaq_f32(vacc[i], vin5, vin5); + vacc[i] = vmlaq_f32(vacc[i], vin6, vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < (channels + 4) >> 2; ++i) { + vacc[i] = vmulq_f32(vacc[i], vscale); + } + + float32x4_t vo[4]; + const float* o = (const float*) output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = vld1q_f32(o); o += 4; + } + float32x4_t v_out[4]; + for (int i = 0; i < num_full_chunks; ++i) { + v_out[i] = vaddq_f32(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + vst1q_f32(output, v_out[i]); output = (void*) ((uintptr_t) 
output + 4 * sizeof(float)); + } + + const size_t pos = channels >> 2; + channels &= 0x3; + float32x2_t vacc_low = vget_low_f32(vacc[pos]); + if (channels & 2) { + vst1_f32(output, vadd_f32(vacc_low, vld1_f32(output))); output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + vacc_low = vget_high_f32(vacc[pos]); + } + if (channels & 1) { + vst1_lane_f32(output, vadd_f32(vacc_low, vld1_dup_f32(output)), 0); + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u32( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float32x4_t vscale = vdupq_n_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + float32x4_t vacc0 = vdupq_n_f32(0.f); + float32x4_t vacc1 = vdupq_n_f32(0.f); + float32x4_t vacc2 = vdupq_n_f32(0.f); + float32x4_t vacc3 = vdupq_n_f32(0.f); + float32x4_t vacc4 = vdupq_n_f32(0.f); + float32x4_t vacc5 = vdupq_n_f32(0.f); + float32x4_t vacc6 = vdupq_n_f32(0.f); + float32x4_t vacc7 = vdupq_n_f32(0.f); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + float32x4_t vin0; + float32x4_t vin1; + float32x4_t vin2; + float32x4_t vin3; + float32x4_t vin4; + float32x4_t vin5; + float32x4_t vin6; + float32x4_t vin7; + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, 
vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[0]))); + vin1 = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[28]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vmulq_f32(vacc0, vscale); + vacc1 = vmulq_f32(vacc1, vscale); + vacc2 = vmulq_f32(vacc2, vscale); + vacc3 = vmulq_f32(vacc3, vscale); + vacc4 = vmulq_f32(vacc4, vscale); + vacc5 = vmulq_f32(vacc5, vscale); + vacc6 = vmulq_f32(vacc6, vscale); + vacc7 = vmulq_f32(vacc7, vscale); + + const float* o = (const float*) output; + float32x4_t vo0 = vld1q_f32(o); o += 4; + float32x4_t vo1 = vld1q_f32(o); o += 4; + float32x4_t vo2 = vld1q_f32(o); o += 4; + float32x4_t vo3 = vld1q_f32(o); o += 4; + float32x4_t vo4 = vld1q_f32(o); o += 4; + float32x4_t vo5 = vld1q_f32(o); o += 4; + float32x4_t vo6 = vld1q_f32(o); o += 4; + float32x4_t vo7 = vld1q_f32(o); o += 4; + float32x4_t v_out0 = vaddq_f32(vo0, vacc0); + float32x4_t v_out1 = vaddq_f32(vo1, vacc1); + float32x4_t v_out2 = vaddq_f32(vo2, vacc2); + float32x4_t v_out3 = vaddq_f32(vo3, vacc3); + float32x4_t v_out4 = vaddq_f32(vo4, vacc4); + float32x4_t v_out5 = vaddq_f32(vo5, vacc5); + float32x4_t v_out6 = vaddq_f32(vo6, vacc6); + float32x4_t v_out7 = vaddq_f32(vo7, vacc7); + vst1q_f32(output, v_out0); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out1); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out2); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out3); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out4); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out5); output = (void*) ((uintptr_t) 
output + 4 * sizeof(float)); + vst1q_f32(output, v_out6); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out7); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 32 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + float32x4_t vacc[8]; + vacc[0] = vdupq_n_f32(0.f); + vacc[1] = vdupq_n_f32(0.f); + vacc[2] = vdupq_n_f32(0.f); + vacc[3] = vdupq_n_f32(0.f); + vacc[4] = vdupq_n_f32(0.f); + vacc[5] = vdupq_n_f32(0.f); + vacc[6] = vdupq_n_f32(0.f); + vacc[7] = vdupq_n_f32(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t num_full_chunks = channels >> 2; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_chunks; ++i) { + float32x4_t vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[i*4]))); + float32x4_t vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[i*4]))); + float32x4_t vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[i*4]))); + float32x4_t vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[i*4]))); + float32x4_t vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[i*4]))); + float32x4_t vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[i*4]))); + float32x4_t vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[i*4]))); + vacc[i] = vmlaq_f32(vacc[i], vin0, vin0); + vacc[i] = vmlaq_f32(vacc[i], vin1, vin1); + vacc[i] = vmlaq_f32(vacc[i], vin2, vin2); + vacc[i] = vmlaq_f32(vacc[i], vin3, vin3); + vacc[i] = vmlaq_f32(vacc[i], vin4, vin4); + vacc[i] = vmlaq_f32(vacc[i], vin5, vin5); + vacc[i] = vmlaq_f32(vacc[i], vin6, vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < (channels + 4) >> 2; ++i) { + vacc[i] = vmulq_f32(vacc[i], vscale); + } + + float32x4_t vo[8]; + const float* o = (const float*) output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = vld1q_f32(o); o += 4; + } + float32x4_t v_out[8]; + for (int i = 0; i < num_full_chunks; ++i) { + v_out[i] = vaddq_f32(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + vst1q_f32(output, v_out[i]); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + 
} + + const size_t pos = channels >> 2; + channels &= 0x3; + float32x2_t vacc_low = vget_low_f32(vacc[pos]); + if (channels & 2) { + vst1_f32(output, vadd_f32(vacc_low, vld1_f32(output))); output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + vacc_low = vget_high_f32(vacc[pos]); + } + if (channels & 1) { + vst1_lane_f32(output, vadd_f32(vacc_low, vld1_dup_f32(output)), 0); + } + } + } + } +} + +void xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u64( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) +{ + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float32x4_t vscale = vdupq_n_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + + float32x4_t vacc0 = vdupq_n_f32(0.f); + float32x4_t vacc1 = vdupq_n_f32(0.f); + float32x4_t vacc2 = vdupq_n_f32(0.f); + float32x4_t vacc3 = vdupq_n_f32(0.f); + float32x4_t vacc4 = vdupq_n_f32(0.f); + float32x4_t vacc5 = vdupq_n_f32(0.f); + float32x4_t vacc6 = vdupq_n_f32(0.f); + float32x4_t vacc7 = vdupq_n_f32(0.f); + float32x4_t vacc8 = vdupq_n_f32(0.f); + float32x4_t vacc9 = vdupq_n_f32(0.f); + float32x4_t vacc10 = vdupq_n_f32(0.f); + float32x4_t vacc11 = vdupq_n_f32(0.f); + float32x4_t vacc12 = vdupq_n_f32(0.f); + float32x4_t vacc13 = vdupq_n_f32(0.f); + float32x4_t vacc14 = vdupq_n_f32(0.f); + float32x4_t vacc15 = vdupq_n_f32(0.f); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + float32x4_t vin0; + float32x4_t vin1; + float32x4_t vin2; + float32x4_t vin3; + float32x4_t vin4; + float32x4_t vin5; + float32x4_t vin6; + float32x4_t vin7; + float32x4_t vin8; + float32x4_t vin9; + float32x4_t vin10; + float32x4_t vin11; + float32x4_t vin12; + float32x4_t vin13; + float32x4_t vin14; + float32x4_t vin15; + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[8]))); + vin3 = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[32]))); + vin9 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[32]))); + vin9 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[16]))); + vin5 = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[32]))); + vin9 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[32]))); + vin9 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[24]))); + vin7 = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[32]))); + vin9 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[32]))); + vin9 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[12]))); + vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[16]))); + vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[20]))); + vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[24]))); + vin7 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[28]))); + vin8 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[32]))); + vin9 = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[36]))); + vin10 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[40]))); + vin11 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[44]))); + vin12 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[48]))); + vin13 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[52]))); + vin14 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[56]))); + vin15 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[60]))); + vacc0 = vmlaq_f32(vacc0, vin0, vin0); + vacc1 = vmlaq_f32(vacc1, vin1, vin1); + vacc2 = vmlaq_f32(vacc2, vin2, vin2); + vacc3 = vmlaq_f32(vacc3, vin3, vin3); + vacc4 = vmlaq_f32(vacc4, vin4, vin4); + vacc5 = vmlaq_f32(vacc5, vin5, vin5); + vacc6 = vmlaq_f32(vacc6, vin6, vin6); + vacc7 = vmlaq_f32(vacc7, vin7, vin7); + vacc8 = vmlaq_f32(vacc8, vin8, vin8); + vacc9 = vmlaq_f32(vacc9, vin9, vin9); + vacc10 = vmlaq_f32(vacc10, vin10, vin10); + vacc11 = vmlaq_f32(vacc11, vin11, vin11); + vacc12 = vmlaq_f32(vacc12, vin12, vin12); + vacc13 = vmlaq_f32(vacc13, vin13, vin13); + vacc14 = vmlaq_f32(vacc14, vin14, vin14); + vacc15 = vmlaq_f32(vacc15, vin15, vin15); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vmulq_f32(vacc0, vscale); + vacc1 = vmulq_f32(vacc1, vscale); + vacc2 = vmulq_f32(vacc2, vscale); + vacc3 = vmulq_f32(vacc3, vscale); + vacc4 = vmulq_f32(vacc4, vscale); + vacc5 = vmulq_f32(vacc5, vscale); + vacc6 = vmulq_f32(vacc6, vscale); + vacc7 = vmulq_f32(vacc7, vscale); + vacc8 = vmulq_f32(vacc8, vscale); + vacc9 = vmulq_f32(vacc9, vscale); + vacc10 = vmulq_f32(vacc10, vscale); + vacc11 = vmulq_f32(vacc11, vscale); + vacc12 = vmulq_f32(vacc12, vscale); + vacc13 = vmulq_f32(vacc13, vscale); + vacc14 = vmulq_f32(vacc14, vscale); + vacc15 = vmulq_f32(vacc15, vscale); + + const float* o = (const float*) output; + float32x4_t vo0 = vld1q_f32(o); o += 4; + float32x4_t vo1 = vld1q_f32(o); o += 4; + float32x4_t vo2 = vld1q_f32(o); o += 4; + float32x4_t vo3 = vld1q_f32(o); o += 4; + float32x4_t vo4 = vld1q_f32(o); o += 4; + float32x4_t vo5 = vld1q_f32(o); o += 4; + float32x4_t vo6 = vld1q_f32(o); o += 4; + float32x4_t vo7 = vld1q_f32(o); o += 4; + float32x4_t vo8 = vld1q_f32(o); o += 4; + float32x4_t vo9 = vld1q_f32(o); o += 4; + float32x4_t vo10 = vld1q_f32(o); o += 4; + float32x4_t vo11 = vld1q_f32(o); o += 4; + float32x4_t vo12 = vld1q_f32(o); o += 4; + float32x4_t vo13 = vld1q_f32(o); o += 4; + float32x4_t vo14 = vld1q_f32(o); o += 4; + float32x4_t vo15 = vld1q_f32(o); o += 4; + float32x4_t v_out0 = vaddq_f32(vo0, vacc0); + float32x4_t v_out1 = vaddq_f32(vo1, vacc1); + float32x4_t v_out2 = vaddq_f32(vo2, vacc2); + float32x4_t v_out3 = vaddq_f32(vo3, vacc3); + float32x4_t v_out4 = vaddq_f32(vo4, vacc4); + float32x4_t v_out5 = vaddq_f32(vo5, vacc5); + float32x4_t v_out6 = vaddq_f32(vo6, vacc6); + float32x4_t v_out7 = vaddq_f32(vo7, vacc7); + float32x4_t v_out8 = vaddq_f32(vo8, vacc8); + float32x4_t v_out9 = vaddq_f32(vo9, vacc9); + float32x4_t v_out10 = vaddq_f32(vo10, vacc10); + float32x4_t v_out11 = vaddq_f32(vo11, vacc11); + float32x4_t v_out12 = vaddq_f32(vo12, vacc12); + float32x4_t v_out13 = vaddq_f32(vo13, vacc13); + float32x4_t v_out14 = vaddq_f32(vo14, vacc14); 
+ float32x4_t v_out15 = vaddq_f32(vo15, vacc15); + vst1q_f32(output, v_out0); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out1); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out2); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out3); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out4); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out5); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out6); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out7); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out8); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out9); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out10); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out11); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out12); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out13); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out14); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + vst1q_f32(output, v_out15); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + 64 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input_row + 1 * input_stride1); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input_row + 2 * input_stride1); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input_row + 3 * input_stride1); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input_row + 4 * input_stride1); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input_row + 5 * input_stride1); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input_row + 6 * input_stride1); + float32x4_t vacc[16]; + vacc[0] = vdupq_n_f32(0.f); + vacc[1] = vdupq_n_f32(0.f); + vacc[2] = vdupq_n_f32(0.f); + vacc[3] = vdupq_n_f32(0.f); + vacc[4] = vdupq_n_f32(0.f); + vacc[5] = vdupq_n_f32(0.f); + vacc[6] = vdupq_n_f32(0.f); + vacc[7] = vdupq_n_f32(0.f); + vacc[8] = vdupq_n_f32(0.f); + vacc[9] = vdupq_n_f32(0.f); + vacc[10] = vdupq_n_f32(0.f); + vacc[11] = vdupq_n_f32(0.f); + vacc[12] = vdupq_n_f32(0.f); + vacc[13] = vdupq_n_f32(0.f); + vacc[14] = vdupq_n_f32(0.f); + vacc[15] = vdupq_n_f32(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t num_full_chunks = channels >> 2; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = (const uint16_t*) zero; + } + for (int i = 0; i < num_chunks; ++i) { + float32x4_t vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[i*4]))); + float32x4_t vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[i*4]))); + float32x4_t vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[i*4]))); + float32x4_t vin3 = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[i*4]))); + float32x4_t vin4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[i*4]))); + float32x4_t vin5 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[i*4]))); + float32x4_t vin6 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[i*4]))); + vacc[i] = vmlaq_f32(vacc[i], vin0, vin0); + vacc[i] = vmlaq_f32(vacc[i], vin1, vin1); + vacc[i] = vmlaq_f32(vacc[i], vin2, vin2); + vacc[i] = vmlaq_f32(vacc[i], vin3, vin3); + vacc[i] = vmlaq_f32(vacc[i], vin4, vin4); + vacc[i] = vmlaq_f32(vacc[i], vin5, vin5); + vacc[i] = vmlaq_f32(vacc[i], vin6, vin6); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < (channels + 4) >> 2; ++i) { + vacc[i] = vmulq_f32(vacc[i], vscale); + } + + float32x4_t vo[16]; + const float* o = (const float*) output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = vld1q_f32(o); o += 4; + } + float32x4_t v_out[16]; + for (int i = 0; i < num_full_chunks; ++i) { + v_out[i] = vaddq_f32(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + vst1q_f32(output, v_out[i]); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + + const size_t pos = channels >> 2; + channels &= 0x3; + float32x2_t vacc_low = vget_low_f32(vacc[pos]); + if (channels & 2) { + vst1_f32(output, vadd_f32(vacc_low, vld1_f32(output))); output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + vacc_low = vget_high_f32(vacc[pos]); + } + if (channels & 1) { + vst1_lane_f32(output, vadd_f32(vacc_low, vld1_dup_f32(output)), 0); + } + } + } + } +} diff --git a/src/f16-f32acc-rdsum2/neon.c.in b/src/f16-f32acc-rdsum2/neon.c.in new file mode 100644 index 00000000000..83820e85f28 --- /dev/null +++ b/src/f16-f32acc-rdsum2/neon.c.in @@ -0,0 +1,148 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
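+
+// This template is expanded by tools/xngen, once per value in CHANNELS_BATCHES
+// (with ACCUMULATORS rows processed per pass), to produce the generated
+// neonfp16arith rdsum2 kernels above, such as
+// xnn_f16_f32acc_rdsum2_ukernel_7p7x__neonfp16arith_u16.
+//
+// Scalar sketch of the semantics each generated kernel implements, for a single
+// input_row slice (names follow the kernel signature; `row_r` is an illustrative
+// helper name, and strides are in bytes):
+//
+//   for (size_t c = 0; c < channels; c++) {
+//     float acc = 0.0f;
+//     for (size_t r = 0; r < k1; r++) {
+//       // row_r = (const xnn_float16*) ((uintptr_t) input_row + r * input_stride1)
+//       const float x = (float) row_r[c];  // f16 -> f32 conversion
+//       acc += x * x;                      // sum of squares
+//     }
+//     output[c] += acc * params->scalar.scale;
+//   }
+//
+// The outer j/k loops repeat this for every input_stride2/input_stride3 slice,
+// accumulating into the same output. The "7p7x" variants walk the rows
+// ACCUMULATORS (7 in the generated kernels above) at a time, redirecting the
+// pointers of already-exhausted rows to the caller-provided `zero` buffer so
+// the tail iterations can reuse the same unrolled inner loop.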
+ +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + +$CHANNELS_BATCHES = tuple(int(cb) for cb in CHANNELS_BATCHES.split(",")) +$for CHANNELS_BATCH in CHANNELS_BATCHES: + $UNROLL = CHANNELS_BATCH >> 2 + + void xnn_f16_f32acc_rdsum2_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__neonfp16arith_u${CHANNELS_BATCH}( + size_t channels, + size_t k1, + size_t k2, + size_t k3, + const xnn_float16* input, + size_t input_stride1, + size_t input_stride2, + size_t input_stride3, + const xnn_float16* zero, + float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) + { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float32x4_t vscale = vdupq_n_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const xnn_float16* input_row = (const xnn_float16*)((uintptr_t)input + j * input_stride2 + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = ${ACCUMULATORS} * input_stride1; + for (; channels >= ${CHANNELS_BATCH}; channels -= ${CHANNELS_BATCH}) { + const uint16_t* i0 = (const uint16_t*) input_row; + $for i in range(1, ACCUMULATORS): + const uint16_t* i${i} = (const uint16_t*) ((uintptr_t) input_row + ${i} * input_stride1); + + $for i in range(UNROLL): + float32x4_t vacc${i} = vdupq_n_f32(0.f); + + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = (const uint16_t*) zero; + } + $for c in range(UNROLL): + float32x4_t vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i${j}[${c*4}]))); + $for c in range(UNROLL): + vacc${c} = vmlaq_f32(vacc${c}, vin${c}, vin${c}); + $for N in range(0, ACCUMULATORS): + i${N} = (const uint16_t*) ((uintptr_t) i${N} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = vmulq_f32(vacc${i}, vscale); + + const float* o = (const float*) output; + $for i in range(0, UNROLL): + float32x4_t vo${i} = vld1q_f32(o); o += 4; + $for i in range(0, UNROLL): + float32x4_t v_out${i} = vaddq_f32(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + vst1q_f32(output, v_out${i}); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + + input_row = (const xnn_float16*) ((uintptr_t) input_row + ${CHANNELS_BATCH} * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride1; + const uint16_t* i0 = (const uint16_t*) input_row; + $for i in range(1, ACCUMULATORS): + const uint16_t* i${i} = (const uint16_t*) ((uintptr_t) input_row + ${i} * input_stride1); + float32x4_t vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = vdupq_n_f32(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t num_full_chunks = channels >> 2; + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = (const uint16_t*) zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = (const uint16_t*) zero; + } + for (int i = 0; i < num_chunks; ++i) { + $for c in range(ACCUMULATORS): + float32x4_t vin${c} = 
vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i${c}[i*4]))); + $for c in range(ACCUMULATORS): + vacc[i] = vmlaq_f32(vacc[i], vin${c}, vin${c}); + } + $for N in range(ACCUMULATORS): + i${N} = (const uint16_t*) ((uintptr_t) i${N} + input_increment); + } + for (int i = 0; i < (channels + 4) >> 2; ++i) { + vacc[i] = vmulq_f32(vacc[i], vscale); + } + + float32x4_t vo[${UNROLL}]; + const float* o = (const float*) output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = vld1q_f32(o); o += 4; + } + float32x4_t v_out[${UNROLL}]; + for (int i = 0; i < num_full_chunks; ++i) { + v_out[i] = vaddq_f32(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + vst1q_f32(output, v_out[i]); output = (void*) ((uintptr_t) output + 4 * sizeof(float)); + } + + const size_t pos = channels >> 2; + channels &= 0x3; + float32x2_t vacc_low = vget_low_f32(vacc[pos]); + if (channels & 2) { + vst1_f32(output, vadd_f32(vacc_low, vld1_f32(output))); output = (void*) ((uintptr_t) output + 2 * sizeof(float)); + vacc_low = vget_high_f32(vacc[pos]); + } + if (channels & 1) { + vst1_lane_f32(output, vadd_f32(vacc_low, vld1_dup_f32(output)), 0); + } + } + } + } + } diff --git a/src/f16-f32acc-rsum2/avx512skx.c.in b/src/f16-f32acc-rsum2/avx512skx.c.in new file mode 100644 index 00000000000..1a91945167a --- /dev/null +++ b/src/f16-f32acc-rsum2/avx512skx.c.in @@ -0,0 +1,92 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/intrinsics-polyfill.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + +$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(",")) +$SIMD_SIZE = BATCH_TILES[0] +$for BATCH_TILE in BATCH_TILES: + $assert BATCH_TILE % SIMD_SIZE == 0 + $assert BATCH_TILE >= SIMD_SIZE + $SIMD_TILE = BATCH_TILE // SIMD_SIZE + $ACCUMULATORS = SIMD_TILE + $while ACCUMULATORS: + $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS + + void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u${BATCH_TILE}${ACC_SUFFIX}( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + $for A in range(ACCUMULATORS): + __m512 vacc${A} = _mm512_setzero_ps(); + for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + $for N in range(1, SIMD_TILE): + __m512 vt${N} = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + ${N * 16}))); + i += ${BATCH_TILE}; + + $for N in range(SIMD_TILE): + vt${N} = _mm512_mul_ps(vt${N}, vt${N}); + + $for N in range(SIMD_TILE): + vacc${N % ACCUMULATORS} = _mm512_add_ps(vacc${N % ACCUMULATORS}, vt${N}); + } + $for N in range(0, SIMD_TILE - 1): + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc${N % ACCUMULATORS} = _mm512_add_ps(vacc${N % ACCUMULATORS}, vt); + } + $REDUCE_ACC = (ACCUMULATORS + 1)//2 + $while REDUCE_ACC > 0: + $for A in range(0, REDUCE_ACC): + $if A + REDUCE_ACC < ACCUMULATORS: + vacc${A} = _mm512_add_ps(vacc${A}, vacc${A + 
REDUCE_ACC}); + $REDUCE_ACC //= 2 + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). + batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; + } + $ACCUMULATORS //= 2 diff --git a/src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc b/src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc new file mode 100644 index 00000000000..a0ed715bbb3 --- /dev/null +++ b/src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc @@ -0,0 +1,46 @@ +// clang-format off +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#if XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u8, 8, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u16, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u24, 24, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u8_acc2, 8, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u16_acc2, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u24_acc3, 24, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc2, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u16_acc4, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc4, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, 
xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u24_acc6, 24, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc8, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +#endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u8, 8, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u16, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u24, 24, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u32, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u16_acc2, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u24_acc3, 24, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u32_acc2, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_rsum2_ukernel__f16c_u32_acc4, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + +#if XNN_ENABLE_AVX512SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u16, 16, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u32, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u48, 48, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u64, 64, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128, 128, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128_acc2, 128, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u32_acc2, 32, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u48_acc3, 48, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, 
xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u64_acc2, 64, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u64_acc4, 64, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512skx, xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128_acc4, 128, false, xnn_float16, float, struct xnn_f16_f32acc_scale_params, xnn_init_f16_f32acc_scale_scalar_params) +#endif // XNN_ENABLE_AVX512SKX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) + diff --git a/src/f16-f32acc-rsum2/f16c.c.in b/src/f16-f32acc-rsum2/f16c.c.in new file mode 100644 index 00000000000..405018b4dcc --- /dev/null +++ b/src/f16-f32acc-rsum2/f16c.c.in @@ -0,0 +1,97 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/intrinsics-polyfill.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/unaligned.h" + +$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(",")) +$SIMD_SIZE = BATCH_TILES[0] +$for BATCH_TILE in BATCH_TILES: + $assert BATCH_TILE % SIMD_SIZE == 0 + $assert BATCH_TILE >= SIMD_SIZE + $SIMD_TILE = BATCH_TILE // SIMD_SIZE + $ACCUMULATORS = SIMD_TILE + $while ACCUMULATORS: + $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS + + void xnn_f16_f32acc_rsum2_ukernel__f16c_u${BATCH_TILE}${ACC_SUFFIX}( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + $for A in range(ACCUMULATORS): + __m256 vacc${A} = _mm256_setzero_ps(); + for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + $for N in range(1, SIMD_TILE): + __m256 vt${N} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + ${N * 8}))); + i += ${BATCH_TILE}; + + $for N in range(SIMD_TILE): + vt${N} = _mm256_mul_ps(vt${N}, vt${N}); + + $for N in range(SIMD_TILE): + vacc${N % ACCUMULATORS} = _mm256_add_ps(vacc${N % ACCUMULATORS}, vt${N}); + } + $for N in range(0, SIMD_TILE - 1): + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc${N % ACCUMULATORS} = _mm256_add_ps(vacc${N % ACCUMULATORS}, vt); + } + $REDUCE_ACC = (ACCUMULATORS + 1) // 2 + $while REDUCE_ACC > 0: + $for A in range(0, REDUCE_ACC): + $if A + REDUCE_ACC < ACCUMULATORS: + vacc${A} = _mm256_add_ps(vacc${A}, vacc${A + REDUCE_ACC}); + $REDUCE_ACC //= 2 + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = 
_mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; + } + $ACCUMULATORS //= 2 diff --git a/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c b/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c new file mode 100644 index 00000000000..a093b2b0218 --- /dev/null +++ b/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c @@ -0,0 +1,1040 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f16-f32acc-rsum2/avx512skx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/intrinsics-polyfill.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u16( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + for (; batch >= 16 * sizeof(uint16_t); batch -= 16 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + + vt0 = _mm512_mul_ps(vt0, vt0); + + vacc0 = _mm512_add_ps(vacc0, vt0); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
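+    // How the masked tail load below works: after the shift, `batch` holds the
+    // number of remaining half-precision elements; `vmask` then has that many
+    // low bits set, so `_mm256_maskz_loadu_epi16` reads only the valid
+    // elements and zero-fills the rest. The zero lanes square to zero and add
+    // nothing to the accumulated sum of squares.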
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u32_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + i += 32; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u32( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + i += 32; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc0 = _mm512_add_ps(vacc0, vt1); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u48_acc3( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + for (; batch >= 48 * sizeof(uint16_t); batch -= 48 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + i += 48; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + vacc2 = _mm512_add_ps(vacc2, vt2); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc2); + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u48( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + for (; batch >= 48 * sizeof(uint16_t); batch -= 48 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + i += 48; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc0 = _mm512_add_ps(vacc0, vt1); + vacc0 = _mm512_add_ps(vacc0, vt2); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u64_acc4( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + for (; batch >= 64 * sizeof(uint16_t); batch -= 64 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + i += 64; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + vacc2 = _mm512_add_ps(vacc2, vt2); + vacc3 = _mm512_add_ps(vacc3, vt3); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc2 = _mm512_add_ps(vacc2, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc2); + vacc1 = _mm512_add_ps(vacc1, vacc3); + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u64_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + for (; batch >= 64 * sizeof(uint16_t); batch -= 64 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + i += 64; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + vacc0 = _mm512_add_ps(vacc0, vt2); + vacc1 = _mm512_add_ps(vacc1, vt3); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u64( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + for (; batch >= 64 * sizeof(uint16_t); batch -= 64 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + i += 64; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc0 = _mm512_add_ps(vacc0, vt1); + vacc0 = _mm512_add_ps(vacc0, vt2); + vacc0 = _mm512_add_ps(vacc0, vt3); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128_acc8( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + __m512 vacc4 = _mm512_setzero_ps(); + __m512 vacc5 = _mm512_setzero_ps(); + __m512 vacc6 = _mm512_setzero_ps(); + __m512 vacc7 = _mm512_setzero_ps(); + for (; batch >= 128 * sizeof(uint16_t); batch -= 128 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + __m512 vt4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 64))); + __m512 vt5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 80))); + __m512 vt6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 96))); + __m512 vt7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 112))); + i += 128; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + vt4 = _mm512_mul_ps(vt4, vt4); + vt5 = _mm512_mul_ps(vt5, vt5); + vt6 = _mm512_mul_ps(vt6, vt6); + vt7 = _mm512_mul_ps(vt7, vt7); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + vacc2 = _mm512_add_ps(vacc2, vt2); + vacc3 = _mm512_add_ps(vacc3, vt3); + vacc4 = _mm512_add_ps(vacc4, vt4); + vacc5 = _mm512_add_ps(vacc5, vt5); + vacc6 = _mm512_add_ps(vacc6, vt6); + vacc7 = _mm512_add_ps(vacc7, vt7); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc2 = _mm512_add_ps(vacc2, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc3 = 
_mm512_add_ps(vacc3, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc4 = _mm512_add_ps(vacc4, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc5 = _mm512_add_ps(vacc5, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc6 = _mm512_add_ps(vacc6, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc4); + vacc1 = _mm512_add_ps(vacc1, vacc5); + vacc2 = _mm512_add_ps(vacc2, vacc6); + vacc3 = _mm512_add_ps(vacc3, vacc7); + vacc0 = _mm512_add_ps(vacc0, vacc2); + vacc1 = _mm512_add_ps(vacc1, vacc3); + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). + batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128_acc4( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + for (; batch >= 128 * sizeof(uint16_t); batch -= 128 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + __m512 vt4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 64))); + __m512 vt5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 80))); + __m512 vt6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 96))); + __m512 vt7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 112))); + i += 128; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + vt4 = _mm512_mul_ps(vt4, vt4); + vt5 = _mm512_mul_ps(vt5, vt5); + vt6 = _mm512_mul_ps(vt6, vt6); + vt7 = _mm512_mul_ps(vt7, vt7); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + vacc2 = _mm512_add_ps(vacc2, vt2); + vacc3 = _mm512_add_ps(vacc3, vt3); + vacc0 = _mm512_add_ps(vacc0, 
vt4); + vacc1 = _mm512_add_ps(vacc1, vt5); + vacc2 = _mm512_add_ps(vacc2, vt6); + vacc3 = _mm512_add_ps(vacc3, vt7); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc2 = _mm512_add_ps(vacc2, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc3 = _mm512_add_ps(vacc3, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc2 = _mm512_add_ps(vacc2, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc2); + vacc1 = _mm512_add_ps(vacc1, vacc3); + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). 
+ batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + for (; batch >= 128 * sizeof(uint16_t); batch -= 128 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + __m512 vt4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 64))); + __m512 vt5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 80))); + __m512 vt6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 96))); + __m512 vt7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 112))); + i += 128; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + vt4 = _mm512_mul_ps(vt4, vt4); + vt5 = _mm512_mul_ps(vt5, vt5); + vt6 = _mm512_mul_ps(vt6, vt6); + vt7 = _mm512_mul_ps(vt7, vt7); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc1 = _mm512_add_ps(vacc1, vt1); + vacc0 = _mm512_add_ps(vacc0, vt2); + vacc1 = _mm512_add_ps(vacc1, vt3); + vacc0 = _mm512_add_ps(vacc0, vt4); + vacc1 = _mm512_add_ps(vacc1, vt5); + vacc0 = _mm512_add_ps(vacc0, vt6); + vacc1 = _mm512_add_ps(vacc1, vt7); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = 
_mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc1 = _mm512_add_ps(vacc1, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + vacc0 = _mm512_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). + batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__avx512skx_u128( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m512 vacc0 = _mm512_setzero_ps(); + for (; batch >= 128 * sizeof(uint16_t); batch -= 128 * sizeof(uint16_t)) { + __m512 vt0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + __m512 vt1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 16))); + __m512 vt2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 32))); + __m512 vt3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 48))); + __m512 vt4 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 64))); + __m512 vt5 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 80))); + __m512 vt6 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 96))); + __m512 vt7 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (i + 112))); + i += 128; + + vt0 = _mm512_mul_ps(vt0, vt0); + vt1 = _mm512_mul_ps(vt1, vt1); + vt2 = _mm512_mul_ps(vt2, vt2); + vt3 = _mm512_mul_ps(vt3, vt3); + vt4 = _mm512_mul_ps(vt4, vt4); + vt5 = _mm512_mul_ps(vt5, vt5); + vt6 = _mm512_mul_ps(vt6, vt6); + vt7 = _mm512_mul_ps(vt7, vt7); + + vacc0 = _mm512_add_ps(vacc0, vt0); + vacc0 = _mm512_add_ps(vacc0, vt1); + vacc0 = _mm512_add_ps(vacc0, vt2); + vacc0 = _mm512_add_ps(vacc0, vt3); + vacc0 = _mm512_add_ps(vacc0, vt4); + vacc0 = _mm512_add_ps(vacc0, vt5); + vacc0 = _mm512_add_ps(vacc0, vt6); + vacc0 = _mm512_add_ps(vacc0, vt7); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if 
(batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if (batch >= 16 * sizeof(uint16_t)) { + __m512 vt = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i)); + i += 16; + batch -= 16 * sizeof(uint16_t); + vt = _mm512_mul_ps(vt, vt); + vacc0 = _mm512_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 15 * sizeof(uint16_t)); + + // Prepare mask for valid elements (depends on batch). + batch >>= XNN_LOG2_SIZEOF_HALF; + const __mmask16 vmask = + _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + + __m512 vt = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i)); + + vt = _mm512_mul_ps(vt, vt); + + vacc0 = _mm512_add_ps(vacc0, vt); + } + const __m256 vacc256 = _mm256_add_ps( + _mm512_castps512_ps256(vacc0), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(vacc0), 1))); + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), + _mm256_extractf128_ps(vacc256, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} diff --git a/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-f16c.c b/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-f16c.c new file mode 100644 index 00000000000..3ef6f1a069f --- /dev/null +++ b/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-f16c.c @@ -0,0 +1,587 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f16-f32acc-rsum2/f16c.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
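+//
+// Note on the tail handling in the F16C kernels below: for a remainder of one
+// to seven half-precision values, the leading pairs are loaded with
+// `_mm_maskload_ps` through a sliding window into `mask_table` (each selected
+// 32-bit lane covers one pair of fp16 values), and an odd trailing element is
+// picked up separately with `_mm_insert_epi16` before being squared and
+// accumulated.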
+ +#include +#include +#include + +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/intrinsics-polyfill.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/unaligned.h" + + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u8( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + + vt0 = _mm256_mul_ps(vt0, vt0); + + vacc0 = _mm256_add_ps(vacc0, vt0); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u16_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + for (; batch >= 16 * sizeof(uint16_t); batch -= 16 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + i += 16; + + vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc1 = _mm256_add_ps(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + vacc0 = _mm256_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) 
((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u16( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + for (; batch >= 16 * sizeof(uint16_t); batch -= 16 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + i += 16; + + vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc0 = _mm256_add_ps(vacc0, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u24_acc3( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + for (; batch >= 24 * sizeof(uint16_t); batch -= 24 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + __m256 vt2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16))); + i += 24; + + 
vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + vt2 = _mm256_mul_ps(vt2, vt2); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc1 = _mm256_add_ps(vacc1, vt1); + vacc2 = _mm256_add_ps(vacc2, vt2); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc1 = _mm256_add_ps(vacc1, vt); + } + vacc0 = _mm256_add_ps(vacc0, vacc2); + vacc0 = _mm256_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u24( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + for (; batch >= 24 * sizeof(uint16_t); batch -= 24 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + __m256 vt2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16))); + i += 24; + + vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + vt2 = _mm256_mul_ps(vt2, vt2); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc0 = _mm256_add_ps(vacc0, vt1); + vacc0 = _mm256_add_ps(vacc0, vt2); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 
= _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u32_acc4( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + __m256 vt2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16))); + __m256 vt3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 24))); + i += 32; + + vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + vt2 = _mm256_mul_ps(vt2, vt2); + vt3 = _mm256_mul_ps(vt3, vt3); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc1 = _mm256_add_ps(vacc1, vt1); + vacc2 = _mm256_add_ps(vacc2, vt2); + vacc3 = _mm256_add_ps(vacc3, vt3); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc1 = _mm256_add_ps(vacc1, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc2 = _mm256_add_ps(vacc2, vt); + } + vacc0 = _mm256_add_ps(vacc0, vacc2); + vacc1 = _mm256_add_ps(vacc1, vacc3); + vacc0 = _mm256_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = 
_mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u32_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + __m256 vt2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16))); + __m256 vt3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 24))); + i += 32; + + vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + vt2 = _mm256_mul_ps(vt2, vt2); + vt3 = _mm256_mul_ps(vt3, vt3); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc1 = _mm256_add_ps(vacc1, vt1); + vacc0 = _mm256_add_ps(vacc0, vt2); + vacc1 = _mm256_add_ps(vacc1, vt3); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc1 = _mm256_add_ps(vacc1, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + vacc0 = _mm256_add_ps(vacc0, vacc1); + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__f16c_u32( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0}; + + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + __m256 vacc0 = _mm256_setzero_ps(); + for (; batch >= 32 * 
sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + __m256 vt0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + __m256 vt1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8))); + __m256 vt2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16))); + __m256 vt3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 24))); + i += 32; + + vt0 = _mm256_mul_ps(vt0, vt0); + vt1 = _mm256_mul_ps(vt1, vt1); + vt2 = _mm256_mul_ps(vt2, vt2); + vt3 = _mm256_mul_ps(vt3, vt3); + + vacc0 = _mm256_add_ps(vacc0, vt0); + vacc0 = _mm256_add_ps(vacc0, vt1); + vacc0 = _mm256_add_ps(vacc0, vt2); + vacc0 = _mm256_add_ps(vacc0, vt3); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if (batch >= 8 * sizeof(uint16_t)) { + __m256 vt = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); + i += 8; + batch -= 8 * sizeof(uint16_t); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + if XNN_UNLIKELY(batch != 0) { + assert(batch >= 1 * sizeof(uint16_t)); + assert(batch <= 7 * sizeof(uint16_t)); + const __m128i vmask = + _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch)); + const __m128i vh = + _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask)); + __m256 vt = _mm256_cvtph_ps(vh); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + i = (const void*) ((uintptr_t) i + batch); + if (batch & (1 * sizeof(uint16_t))) { + const __m128i vh = _mm_insert_epi16(_mm_setzero_si128(), + (int) unaligned_load_u16(i - 1), 0); + __m256 vt = _mm256_zextps128_ps256(_mm_cvtph_ps(vh)); + vt = _mm256_mul_ps(vt, vt); + vacc0 = _mm256_add_ps(vacc0, vt); + } + } + __m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), + _mm256_extractf128_ps(vacc0, 1)); + vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc)); + vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc)); + vacc = _mm_mul_ss(vacc, _mm_load_ss(¶ms->scalar.scale)); + + float vout = _mm_cvtss_f32(vacc); + *output += vout; +} diff --git a/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-neonfp16arith.c b/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-neonfp16arith.c new file mode 100644 index 00000000000..a1a94892495 --- /dev/null +++ b/src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-neonfp16arith.c @@ -0,0 +1,1152 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f16-f32acc-rsum2/neonfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
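+//
+// Reviewer note (not generator output): each `rsum2` variant in this file
+// computes a scaled sum of squares. The f16 inputs are widened to f32,
+// squared, accumulated in one or more f32x4 accumulators, horizontally
+// reduced, multiplied by `params->scalar.scale`, and added onto `*output`.
+// A minimal scalar sketch of that contract, assuming a hypothetical
+// `f16_to_f32()` widening helper (illustrative only, not a real API here):
+//
+//   float acc = 0.0f;
+//   for (size_t n = 0; n < batch / sizeof(uint16_t); n++) {
+//     const float x = f16_to_f32(((const uint16_t*) input)[n]);
+//     acc += x * x;
+//   }
+//   *output += acc * params->scalar.scale;
+//
+// The `_accN` suffixes only change how many independent accumulators break
+// the add dependency chain; results match up to floating-point rounding.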
+ +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u8_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u8( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + 
float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u16_acc4( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + float32x4_t vacc2 = vmovq_n_f32(0.0f); + float32x4_t vacc3 = vmovq_n_f32(0.0f); + for (; batch >= 16 * sizeof(uint16_t); batch -= 16 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc2 = vaddq_f32(vacc2, vt2); + vacc3 = vaddq_f32(vacc3, vt3); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc2); + vacc1 = vaddq_f32(vacc1, vacc3); + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u16_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + for (; batch >= 16 * sizeof(uint16_t); batch -= 16 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = 
vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc0 = vaddq_f32(vacc0, vt2); + vacc1 = vaddq_f32(vacc1, vt3); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u16( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + for (; batch >= 16 * sizeof(uint16_t); batch -= 16 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + vacc0 = vaddq_f32(vacc0, vt2); + vacc0 = vaddq_f32(vacc0, vt3); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = 
vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u24_acc6( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + float32x4_t vacc2 = vmovq_n_f32(0.0f); + float32x4_t vacc3 = vmovq_n_f32(0.0f); + float32x4_t vacc4 = vmovq_n_f32(0.0f); + float32x4_t vacc5 = vmovq_n_f32(0.0f); + for (; batch >= 24 * sizeof(uint16_t); batch -= 24 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc2 = vaddq_f32(vacc2, vt2); + vacc3 = vaddq_f32(vacc3, vt3); + vacc4 = vaddq_f32(vacc4, vt4); + vacc5 = vaddq_f32(vacc5, vt5); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc2 = vaddq_f32(vacc2, vt0); + vacc3 = vaddq_f32(vacc3, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc3); + vacc1 = vaddq_f32(vacc1, vacc4); + vacc2 = vaddq_f32(vacc2, vacc5); + vacc1 = vaddq_f32(vacc1, vacc2); + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = 
vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u24_acc3( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + float32x4_t vacc2 = vmovq_n_f32(0.0f); + for (; batch >= 24 * sizeof(uint16_t); batch -= 24 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc2 = vaddq_f32(vacc2, vt2); + vacc0 = vaddq_f32(vacc0, vt3); + vacc1 = vaddq_f32(vacc1, vt4); + vacc2 = vaddq_f32(vacc2, vt5); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc2 = vaddq_f32(vacc2, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + vacc1 = vaddq_f32(vacc1, vacc2); + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = 
vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u24( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + for (; batch >= 24 * sizeof(uint16_t); batch -= 24 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + vacc0 = vaddq_f32(vacc0, vt2); + vacc0 = vaddq_f32(vacc0, vt3); + vacc0 = vaddq_f32(vacc0, vt4); + vacc0 = vaddq_f32(vacc0, vt5); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + 
*output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc8( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + float32x4_t vacc2 = vmovq_n_f32(0.0f); + float32x4_t vacc3 = vmovq_n_f32(0.0f); + float32x4_t vacc4 = vmovq_n_f32(0.0f); + float32x4_t vacc5 = vmovq_n_f32(0.0f); + float32x4_t vacc6 = vmovq_n_f32(0.0f); + float32x4_t vacc7 = vmovq_n_f32(0.0f); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh67 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + float32x4_t vt6 = vcvt_f32_f16(vget_low_f16(vh67)); + float32x4_t vt7 = vcvt_f32_f16(vget_high_f16(vh67)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + vt6 = vmulq_f32(vt6, vt6); + vt7 = vmulq_f32(vt7, vt7); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc2 = vaddq_f32(vacc2, vt2); + vacc3 = vaddq_f32(vacc3, vt3); + vacc4 = vaddq_f32(vacc4, vt4); + vacc5 = vaddq_f32(vacc5, vt5); + vacc6 = vaddq_f32(vacc6, vt6); + vacc7 = vaddq_f32(vacc7, vt7); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc2 = vaddq_f32(vacc2, vt0); + vacc3 = vaddq_f32(vacc3, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc4 = vaddq_f32(vacc4, vt0); + vacc5 = vaddq_f32(vacc5, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc4); + vacc1 = vaddq_f32(vacc1, vacc5); + vacc2 = vaddq_f32(vacc2, vacc6); + vacc3 = vaddq_f32(vacc3, vacc7); + vacc0 = vaddq_f32(vacc0, vacc2); + vacc1 = vaddq_f32(vacc1, vacc3); + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = 
vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc4( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + float32x4_t vacc2 = vmovq_n_f32(0.0f); + float32x4_t vacc3 = vmovq_n_f32(0.0f); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh67 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + float32x4_t vt6 = vcvt_f32_f16(vget_low_f16(vh67)); + float32x4_t vt7 = vcvt_f32_f16(vget_high_f16(vh67)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + vt6 = vmulq_f32(vt6, vt6); + vt7 = vmulq_f32(vt7, vt7); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc2 = vaddq_f32(vacc2, vt2); + vacc3 = vaddq_f32(vacc3, vt3); + vacc0 = vaddq_f32(vacc0, vt4); + vacc1 = vaddq_f32(vacc1, vt5); + vacc2 = vaddq_f32(vacc2, vt6); + vacc3 = vaddq_f32(vacc3, vt7); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc2 = vaddq_f32(vacc2, vt0); + vacc3 = vaddq_f32(vacc3, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 
= vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc2); + vacc1 = vaddq_f32(vacc1, vacc3); + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32_acc2( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + float32x4_t vacc1 = vmovq_n_f32(0.0f); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh67 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + float32x4_t vt6 = vcvt_f32_f16(vget_low_f16(vh67)); + float32x4_t vt7 = vcvt_f32_f16(vget_high_f16(vh67)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + vt6 = vmulq_f32(vt6, vt6); + vt7 = vmulq_f32(vt7, vt7); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + vacc0 = vaddq_f32(vacc0, vt2); + vacc1 = vaddq_f32(vacc1, vt3); + vacc0 = vaddq_f32(vacc0, vt4); + vacc1 = vaddq_f32(vacc1, vt5); + vacc0 = vaddq_f32(vacc0, vt6); + vacc1 = vaddq_f32(vacc1, vt7); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = 
vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc1 = vaddq_f32(vacc1, vt1); + } + vacc0 = vaddq_f32(vacc0, vacc1); + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} + +void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u32( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + float32x4_t vacc0 = vmovq_n_f32(0.0f); + for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) { + const float16x8_t vh01 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh23 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh45 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + const float16x8_t vh67 = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh01)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh01)); + float32x4_t vt2 = vcvt_f32_f16(vget_low_f16(vh23)); + float32x4_t vt3 = vcvt_f32_f16(vget_high_f16(vh23)); + float32x4_t vt4 = vcvt_f32_f16(vget_low_f16(vh45)); + float32x4_t vt5 = vcvt_f32_f16(vget_high_f16(vh45)); + float32x4_t vt6 = vcvt_f32_f16(vget_low_f16(vh67)); + float32x4_t vt7 = vcvt_f32_f16(vget_high_f16(vh67)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + vt2 = vmulq_f32(vt2, vt2); + vt3 = vmulq_f32(vt3, vt3); + vt4 = vmulq_f32(vt4, vt4); + vt5 = vmulq_f32(vt5, vt5); + vt6 = vmulq_f32(vt6, vt6); + vt7 = vmulq_f32(vt7, vt7); + + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + vacc0 = vaddq_f32(vacc0, vt2); + vacc0 = vaddq_f32(vacc0, vt3); + vacc0 = vaddq_f32(vacc0, vt4); + vacc0 = vaddq_f32(vacc0, vt5); + vacc0 = vaddq_f32(vacc0, vt6); + vacc0 = vaddq_f32(vacc0, vt7); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = 
vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc0 = vaddq_f32(vacc0, vt0); + vacc0 = vaddq_f32(vacc0, vt1); + } + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; +} diff --git a/src/f16-f32acc-rsum2/neonfp16arith.c.in b/src/f16-f32acc-rsum2/neonfp16arith.c.in new file mode 100644 index 00000000000..0b31389a584 --- /dev/null +++ b/src/f16-f32acc-rsum2/neonfp16arith.c.in @@ -0,0 +1,102 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
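+//
+// Reviewer note: this xngen template emits one ukernel per
+// (BATCH_TILE, ACCUMULATORS) pair. With BATCH_TILES = "8,16,24,32" (the set
+// that matches the generated file above), SIMD_SIZE is 4 f32 lanes, and for
+// each batch tile the accumulator count starts at BATCH_TILE / 4 and is
+// halved (integer division) down to 1. A sketch of the expansion in plain C,
+// illustrative only:
+//
+//   for (int batch_tile = 8; batch_tile <= 32; batch_tile += 8) {
+//     for (int acc = batch_tile / 4; acc != 0; acc /= 2) {
+//       // acc == 1 emits ..._u<batch_tile>; otherwise ..._u<batch_tile>_acc<acc>,
+//       // e.g. batch_tile == 24 yields u24_acc6, u24_acc3, u24.
+//     }
+//   }
+//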
+ +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" + +$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(",")) +$SIMD_SIZE = BATCH_TILES[0] // 2 +$for BATCH_TILE in BATCH_TILES: + $assert BATCH_TILE % SIMD_SIZE == 0 + $assert BATCH_TILE >= SIMD_SIZE + $SIMD_TILE = BATCH_TILE // SIMD_SIZE + $ACCUMULATORS = SIMD_TILE + $while ACCUMULATORS: + $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS + + void xnn_f16_f32acc_rsum2_ukernel__neonfp16arith_u${BATCH_TILE}${ACC_SUFFIX}( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(uint16_t) == 0); + assert(input != NULL); + assert(output != NULL); + + const uint16_t* i = (const uint16_t*) input; + $for A in range(ACCUMULATORS): + float32x4_t vacc${A} = vmovq_n_f32(0.0f); + for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) { + $for N in range(0, SIMD_TILE, 2): + const float16x8_t vh${ABC[N:N+2]} = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + $for N in range(0, SIMD_TILE, 2): + float32x4_t vt${N} = vcvt_f32_f16(vget_low_f16(vh${ABC[N:N+2]})); + float32x4_t vt${N+1} = vcvt_f32_f16(vget_high_f16(vh${ABC[N:N+2]})); + + $for N in range(((SIMD_TILE + 1)//2)*2): + vt${N} = vmulq_f32(vt${N}, vt${N}); + + $for N in range(((SIMD_TILE + 1)//2)*2): + vacc${N % ACCUMULATORS} = vaddq_f32(vacc${N % ACCUMULATORS}, vt${N}); + } + $for N in range(0, SIMD_TILE - 2, 2): + if (batch >= 8 * sizeof(uint16_t)) { + const float16x8_t vh = vreinterpretq_f16_u16(vld1q_u16(i)); + i += 8; + + float32x4_t vt0 = vcvt_f32_f16(vget_low_f16(vh)); + float32x4_t vt1 = vcvt_f32_f16(vget_high_f16(vh)); + + vt0 = vmulq_f32(vt0, vt0); + vt1 = vmulq_f32(vt1, vt1); + + batch -= 8 * sizeof(uint16_t); + vacc${N % ACCUMULATORS} = vaddq_f32(vacc${N % ACCUMULATORS}, vt0); + vacc${(N + 1) % ACCUMULATORS} = vaddq_f32(vacc${(N + 1) % ACCUMULATORS}, vt1); + } + $REDUCE_ACC = ACCUMULATORS + $while REDUCE_ACC > 1: + $for A in range(REDUCE_ACC % 2, (REDUCE_ACC + 1) // 2): + $if A + REDUCE_ACC // 2 <= REDUCE_ACC: + vacc${A} = vaddq_f32(vacc${A}, vacc${A + REDUCE_ACC // 2}); + $REDUCE_ACC = (REDUCE_ACC + 1) // 2 + const float32x2_t vscale = vdup_n_f32(params->scalar.scale); + if XNN_UNLIKELY(batch & (4 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_u16((const void*) i)); + i += 4; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc0 = vaddq_f32(vacc0, vt); + } + float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)); + if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); + i += 2; + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vpadd_f32(vacc, vacc); + if XNN_UNLIKELY(batch & (1 * sizeof(uint16_t))) { + const float16x4_t vh = vreinterpret_f16_u16(vld1_dup_u16(i)); + float32x4_t vt = vcvt_f32_f16(vh); + vt = vmulq_f32(vt, vt); + vacc = vadd_f32(vacc, vget_low_f32(vt)); + } + vacc = vmul_f32(vacc, vscale); + + float vout = vget_lane_f32(vacc, 0); + *output += vout; + } + $ACCUMULATORS //= 2 diff --git a/src/f32-rdsum2/f32-rdsum2.inc b/src/f32-rdsum2/f32-rdsum2.inc new file mode 100644 index 00000000000..1a06de6dcf5 --- /dev/null +++ 
b/src/f32-rdsum2/f32-rdsum2.inc @@ -0,0 +1,42 @@ +// clang-format off +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__scalar_u4, 7, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rdsum2_ukernel_7p7x__neon_u16, 7, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rdsum2_ukernel_7p7x__neon_u32, 7, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rdsum2_ukernel_7p7x__neon_u64, 7, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__sse2_u16, 7, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__sse2_u32, 7, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__sse2_u64, 7, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rdsum2_ukernel_7p7x__avx_u16, 7, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rdsum2_ukernel_7p7x__avx_u32, 7, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rdsum2_ukernel_7p7x__avx_u64, 7, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + +#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rdsum2_ukernel_7p7x__avx512f_u16, 7, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rdsum2_ukernel_7p7x__avx512f_u32, 7, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rdsum2_ukernel_7p7x__avx512f_u64, 7, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rdsum2_ukernel_7p7x__avx512f_u128, 7, 128, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) + +#if XNN_ARCH_HEXAGON && XNN_ENABLE_HVX +XNN_UKERNEL(xnn_arch_hvx, xnn_f32_rdsum2_ukernel_7p7x__hvx_u32, 7, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_hvx, xnn_f32_rdsum2_ukernel_7p7x__hvx_u64, 7, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_hvx, xnn_f32_rdsum2_ukernel_7p7x__hvx_u128, 7, 128, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +#endif // XNN_ARCH_HEXAGON && XNN_ENABLE_HVX + +#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u16, 7, 16, false, float, float, struct xnn_f32_scale_params, 
xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u32, 7, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u64, 7, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c new file mode 100644 index 00000000000..cbf462f3c7a --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c @@ -0,0 +1,994 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-avx.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__avx_u16( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = 
xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 8; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + xnn_storeu_f32(output, vacc0); output += 8; + xnn_storeu_f32(output, vacc1); output += 8; + + input_row = (const float*) ((uintptr_t) input_row + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[2]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*8]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*8]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*8]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*8]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*8]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*8]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + 
vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*8], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*8], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*8], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*8], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*8], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*8], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*8], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[2]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = xnn_loadu_f32(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__avx_u32( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const float* i0 = 
input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[8]); + vin2 = xnn_loadu_f32(&i0[16]); + vin3 = xnn_loadu_f32(&i0[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[8]); + vin2 = xnn_loadu_f32(&i1[16]); + vin3 = xnn_loadu_f32(&i1[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[8]); + vin2 = xnn_loadu_f32(&i2[16]); + vin3 = xnn_loadu_f32(&i2[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[8]); + vin2 = xnn_loadu_f32(&i3[16]); + vin3 = xnn_loadu_f32(&i3[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[8]); + vin2 = xnn_loadu_f32(&i4[16]); + vin3 = xnn_loadu_f32(&i4[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[8]); + vin2 = xnn_loadu_f32(&i5[16]); + vin3 = xnn_loadu_f32(&i5[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, 
vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[8]); + vin2 = xnn_loadu_f32(&i6[16]); + vin3 = xnn_loadu_f32(&i6[24]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 8; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 8; + xnn_storeu_f32(output, vacc1); output += 8; + xnn_storeu_f32(output, vacc2); output += 8; + xnn_storeu_f32(output, vacc3); output += 8; + + input_row = (const float*) ((uintptr_t) input_row + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*8]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*8]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*8]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*8]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*8]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*8]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, 
vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*8], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*8], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*8], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*8], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*8], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*8], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*8], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = xnn_loadu_f32(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__avx_u64( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* 
i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[8]); + vin2 = xnn_loadu_f32(&i0[16]); + vin3 = xnn_loadu_f32(&i0[24]); + vin4 = xnn_loadu_f32(&i0[32]); + vin5 = xnn_loadu_f32(&i0[40]); + vin6 = xnn_loadu_f32(&i0[48]); + vin7 = xnn_loadu_f32(&i0[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[8]); + vin2 = xnn_loadu_f32(&i1[16]); + vin3 = xnn_loadu_f32(&i1[24]); + vin4 = xnn_loadu_f32(&i1[32]); + vin5 = xnn_loadu_f32(&i1[40]); + vin6 = xnn_loadu_f32(&i1[48]); + vin7 = xnn_loadu_f32(&i1[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[8]); + vin2 = xnn_loadu_f32(&i2[16]); + vin3 = xnn_loadu_f32(&i2[24]); + vin4 = xnn_loadu_f32(&i2[32]); + vin5 = xnn_loadu_f32(&i2[40]); + vin6 = xnn_loadu_f32(&i2[48]); + vin7 = xnn_loadu_f32(&i2[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = 
xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[8]); + vin2 = xnn_loadu_f32(&i3[16]); + vin3 = xnn_loadu_f32(&i3[24]); + vin4 = xnn_loadu_f32(&i3[32]); + vin5 = xnn_loadu_f32(&i3[40]); + vin6 = xnn_loadu_f32(&i3[48]); + vin7 = xnn_loadu_f32(&i3[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[8]); + vin2 = xnn_loadu_f32(&i4[16]); + vin3 = xnn_loadu_f32(&i4[24]); + vin4 = xnn_loadu_f32(&i4[32]); + vin5 = xnn_loadu_f32(&i4[40]); + vin6 = xnn_loadu_f32(&i4[48]); + vin7 = xnn_loadu_f32(&i4[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[8]); + vin2 = xnn_loadu_f32(&i5[16]); + vin3 = xnn_loadu_f32(&i5[24]); + vin4 = xnn_loadu_f32(&i5[32]); + vin5 = xnn_loadu_f32(&i5[40]); + vin6 = xnn_loadu_f32(&i5[48]); + vin7 = xnn_loadu_f32(&i5[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[8]); + vin2 = xnn_loadu_f32(&i6[16]); + vin3 = xnn_loadu_f32(&i6[24]); + vin4 = xnn_loadu_f32(&i6[32]); + vin5 = xnn_loadu_f32(&i6[40]); + vin6 = xnn_loadu_f32(&i6[48]); + vin7 = xnn_loadu_f32(&i6[56]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 
+ input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 8; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 8; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + xnn_storeu_f32(output, vacc0); output += 8; + xnn_storeu_f32(output, vacc1); output += 8; + xnn_storeu_f32(output, vacc2); output += 8; + xnn_storeu_f32(output, vacc3); output += 8; + xnn_storeu_f32(output, vacc4); output += 8; + xnn_storeu_f32(output, vacc5); output += 8; + xnn_storeu_f32(output, vacc6); output += 8; + xnn_storeu_f32(output, vacc7); output += 8; + + input_row = (const float*) ((uintptr_t) input_row + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[8]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 7; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*8]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*8]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*8]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*8]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*8]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*8]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*8]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = 
xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*8], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*8], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*8], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*8], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*8], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*8], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*8], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = xnn_loadu_f32(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c new file mode 100644 index 00000000000..f12280dda22 --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c @@ -0,0 +1,1216 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-avx512f.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__avx512f_u16( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + vin0 = xnn_loadu_f32(&i0[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i1[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i2[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i3[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i4[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i5[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i6[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 16; + vacc0 = xnn_add_f32(vo0, vacc0); + xnn_storeu_f32(output, vacc0); output += 16; + + input_row = (const float*) ((uintptr_t) input_row + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) 
((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[1]; + vacc[0] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 15; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*16]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*16]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*16]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*16]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*16]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*16]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*16], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*16], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*16], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*16], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*16], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*16], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*16], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + 
i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[1]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = xnn_loadu_f32(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__avx512f_u32( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = 
xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 16; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + xnn_storeu_f32(output, vacc0); output += 16; + xnn_storeu_f32(output, vacc1); output += 16; + + input_row = (const float*) ((uintptr_t) input_row + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[2]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 15; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*16]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*16]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*16]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*16]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*16]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*16]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*16], remainder); + xnn_simd_f32_t 
vin1 = xnn_load_tail_f32(&i1[num_full_chunks*16], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*16], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*16], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*16], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*16], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*16], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[2]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = xnn_loadu_f32(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__avx512f_u64( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) 
((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[16]); + vin2 = xnn_loadu_f32(&i0[32]); + vin3 = xnn_loadu_f32(&i0[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[16]); + vin2 = xnn_loadu_f32(&i1[32]); + vin3 = xnn_loadu_f32(&i1[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[16]); + vin2 = xnn_loadu_f32(&i2[32]); + vin3 = xnn_loadu_f32(&i2[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[16]); + vin2 = xnn_loadu_f32(&i3[32]); + vin3 = xnn_loadu_f32(&i3[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[16]); + vin2 = xnn_loadu_f32(&i4[32]); + vin3 = xnn_loadu_f32(&i4[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[16]); + vin2 = xnn_loadu_f32(&i5[32]); + vin3 = xnn_loadu_f32(&i5[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[16]); + vin2 = xnn_loadu_f32(&i6[32]); + vin3 = xnn_loadu_f32(&i6[48]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, 
vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 16; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 16; + xnn_storeu_f32(output, vacc1); output += 16; + xnn_storeu_f32(output, vacc2); output += 16; + xnn_storeu_f32(output, vacc3); output += 16; + + input_row = (const float*) ((uintptr_t) input_row + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 15; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*16]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*16]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*16]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*16]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*16]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*16]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*16], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*16], remainder); + 
xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*16], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*16], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*16], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*16], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*16], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = xnn_loadu_f32(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__avx512f_u128( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 128; channels -= 128) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + 
(const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[16]); + vin2 = xnn_loadu_f32(&i0[32]); + vin3 = xnn_loadu_f32(&i0[48]); + vin4 = xnn_loadu_f32(&i0[64]); + vin5 = xnn_loadu_f32(&i0[80]); + vin6 = xnn_loadu_f32(&i0[96]); + vin7 = xnn_loadu_f32(&i0[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[16]); + vin2 = xnn_loadu_f32(&i1[32]); + vin3 = xnn_loadu_f32(&i1[48]); + vin4 = xnn_loadu_f32(&i1[64]); + vin5 = xnn_loadu_f32(&i1[80]); + vin6 = xnn_loadu_f32(&i1[96]); + vin7 = xnn_loadu_f32(&i1[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[16]); + vin2 = xnn_loadu_f32(&i2[32]); + vin3 = xnn_loadu_f32(&i2[48]); + vin4 = xnn_loadu_f32(&i2[64]); + vin5 = xnn_loadu_f32(&i2[80]); + vin6 = xnn_loadu_f32(&i2[96]); + vin7 = xnn_loadu_f32(&i2[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[16]); + vin2 = xnn_loadu_f32(&i3[32]); + vin3 = xnn_loadu_f32(&i3[48]); + vin4 = xnn_loadu_f32(&i3[64]); + vin5 = 
xnn_loadu_f32(&i3[80]); + vin6 = xnn_loadu_f32(&i3[96]); + vin7 = xnn_loadu_f32(&i3[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[16]); + vin2 = xnn_loadu_f32(&i4[32]); + vin3 = xnn_loadu_f32(&i4[48]); + vin4 = xnn_loadu_f32(&i4[64]); + vin5 = xnn_loadu_f32(&i4[80]); + vin6 = xnn_loadu_f32(&i4[96]); + vin7 = xnn_loadu_f32(&i4[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[16]); + vin2 = xnn_loadu_f32(&i5[32]); + vin3 = xnn_loadu_f32(&i5[48]); + vin4 = xnn_loadu_f32(&i5[64]); + vin5 = xnn_loadu_f32(&i5[80]); + vin6 = xnn_loadu_f32(&i5[96]); + vin7 = xnn_loadu_f32(&i5[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[16]); + vin2 = xnn_loadu_f32(&i6[32]); + vin3 = xnn_loadu_f32(&i6[48]); + vin4 = xnn_loadu_f32(&i6[64]); + vin5 = xnn_loadu_f32(&i6[80]); + vin6 = xnn_loadu_f32(&i6[96]); + vin7 = xnn_loadu_f32(&i6[112]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = 
xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 16; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 16; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + xnn_storeu_f32(output, vacc0); output += 16; + xnn_storeu_f32(output, vacc1); output += 16; + xnn_storeu_f32(output, vacc2); output += 16; + xnn_storeu_f32(output, vacc3); output += 16; + xnn_storeu_f32(output, vacc4); output += 16; + xnn_storeu_f32(output, vacc5); output += 16; + xnn_storeu_f32(output, vacc6); output += 16; + xnn_storeu_f32(output, vacc7); output += 16; + + input_row = (const float*) ((uintptr_t) input_row + 128 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[8]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 15; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*16]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*16]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*16]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*16]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*16]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*16]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*16]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = 
xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*16], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*16], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*16], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*16], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*16], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*16], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*16], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = xnn_loadu_f32(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-hvx.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-hvx.c new file mode 100644 index 00000000000..35a62f276ad --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-hvx.c @@ -0,0 +1,798 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
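// --- Editor's sketch: a minimal scalar reference for what the
// xnn_f32_rdsum2_ukernel_7p7x__* kernels below compute, inferred from the
// generated code rather than taken from the original sources. The reading of
// k1/k2/k3 as reduction extents and of the strides as byte strides is an
// assumption; the `zero` padding row used by the 7-row unrolling is omitted
// because a scalar loop does not need it.
#include <stddef.h>

static void f32_rdsum2_scalar_reference(
    size_t channels, size_t k1, size_t k2, size_t k3, const float* input,
    size_t input_stride1, size_t input_stride2, size_t input_stride3,
    float* output, float scale) {
  for (size_t k = 0; k < k3; ++k) {
    for (size_t j = 0; j < k2; ++j) {
      // Base of the (j, k) slice; strides are byte strides, as in the kernels.
      const char* slice =
          (const char*) input + j * input_stride2 + k * input_stride3;
      for (size_t c = 0; c < channels; ++c) {
        float acc = 0.0f;
        for (size_t r = 0; r < k1; ++r) {
          const float* row = (const float*) (slice + r * input_stride1);
          acc += row[c] * row[c];  // sum of squares along the reduction axis
        }
        // Scale once per slice and accumulate into the caller-provided output,
        // matching the kernels' read-modify-write of `output` (up to
        // floating-point rounding order).
        output[c] += acc * scale;
      }
    }
  }
}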
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-hvx.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__hvx_u32( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + vin0 = xnn_loadu_f32(&i0[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i1[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i2[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i3[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i4[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i5[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + vin0 = xnn_loadu_f32(&i6[0]); + vin0 = xnn_mul_f32(vin0, vin0); + vacc0 = xnn_add_f32(vin0, vacc0); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 32; + vacc0 = xnn_add_f32(vo0, vacc0); + xnn_storeu_f32(output, vacc0); output += 32; + + input_row = (const float*) ((uintptr_t) input_row + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) 
((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[1]; + vacc[0] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 5; + const size_t num_chunks = round_up_po2(channels, 32) >> 5; + const size_t remainder = channels & 31; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*32]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*32]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*32]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*32]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*32]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*32]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*32], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*32], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*32], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*32], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*32], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*32], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*32], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + 
i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[1]; + const float* o = output; + for (int i = 0; i < channels >> 5; ++i) { + vo[i] = xnn_loadu_f32(o); o += 32; + } + for (int i = 0; i < channels >> 5; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 5; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 32; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__hvx_u64( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, 
vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 32; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 32; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + xnn_storeu_f32(output, vacc0); output += 32; + xnn_storeu_f32(output, vacc1); output += 32; + + input_row = (const float*) ((uintptr_t) input_row + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[2]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 5; + const size_t num_chunks = round_up_po2(channels, 32) >> 5; + const size_t remainder = channels & 31; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*32]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*32]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*32]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*32]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*32]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*32]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*32], remainder); + xnn_simd_f32_t vin1 = 
xnn_load_tail_f32(&i1[num_full_chunks*32], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*32], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*32], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*32], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*32], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*32], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[2]; + const float* o = output; + for (int i = 0; i < channels >> 5; ++i) { + vo[i] = xnn_loadu_f32(o); o += 32; + } + for (int i = 0; i < channels >> 5; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 5; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 32; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__hvx_u128( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 128; channels -= 128) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) 
input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[32]); + vin2 = xnn_loadu_f32(&i0[64]); + vin3 = xnn_loadu_f32(&i0[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[32]); + vin2 = xnn_loadu_f32(&i1[64]); + vin3 = xnn_loadu_f32(&i1[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[32]); + vin2 = xnn_loadu_f32(&i2[64]); + vin3 = xnn_loadu_f32(&i2[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[32]); + vin2 = xnn_loadu_f32(&i3[64]); + vin3 = xnn_loadu_f32(&i3[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[32]); + vin2 = xnn_loadu_f32(&i4[64]); + vin3 = xnn_loadu_f32(&i4[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[32]); + vin2 = xnn_loadu_f32(&i5[64]); + vin3 = xnn_loadu_f32(&i5[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[32]); + vin2 = xnn_loadu_f32(&i6[64]); + vin3 = xnn_loadu_f32(&i6[96]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + 
vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 32; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 32; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 32; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 32; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 32; + xnn_storeu_f32(output, vacc1); output += 32; + xnn_storeu_f32(output, vacc2); output += 32; + xnn_storeu_f32(output, vacc3); output += 32; + + input_row = (const float*) ((uintptr_t) input_row + 128 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 5; + const size_t num_chunks = round_up_po2(channels, 32) >> 5; + const size_t remainder = channels & 31; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*32]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*32]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*32]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*32]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*32]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*32]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*32]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*32], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*32], remainder); + 
xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*32], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*32], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*32], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*32], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*32], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 5; ++i) { + vo[i] = xnn_loadu_f32(o); o += 32; + } + for (int i = 0; i < channels >> 5; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 5; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 32; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-neon.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-neon.c new file mode 100644 index 00000000000..5c9d1e80ce4 --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-neon.c @@ -0,0 +1,1386 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
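// Editor's note on the 7p7x structure shared by these generated variants:
// rows are consumed seven at a time, and when fewer than seven rows remain the
// unused row pointers (i1..i6) are redirected to the caller-provided `zero`
// buffer by the XNN_UNPREDICTABLE branches, so the extra loads contribute
// nothing to the sum of squares. The channel dimension is processed in full
// SIMD-width chunks first, with the xnn_load_tail_f32/xnn_store_tail_f32
// helpers covering the remaining channels.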
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-neon.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__neon_u16( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = 
xnn_loadu_f32(&i3[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 
2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, 
remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__neon_u32( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin4 = xnn_loadu_f32(&i0[16]); + vin5 = xnn_loadu_f32(&i0[20]); + vin6 = xnn_loadu_f32(&i0[24]); + vin7 = xnn_loadu_f32(&i0[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin4 = xnn_loadu_f32(&i1[16]); + vin5 = xnn_loadu_f32(&i1[20]); + vin6 = xnn_loadu_f32(&i1[24]); + vin7 = xnn_loadu_f32(&i1[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 
= xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin4 = xnn_loadu_f32(&i2[16]); + vin5 = xnn_loadu_f32(&i2[20]); + vin6 = xnn_loadu_f32(&i2[24]); + vin7 = xnn_loadu_f32(&i2[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin4 = xnn_loadu_f32(&i3[16]); + vin5 = xnn_loadu_f32(&i3[20]); + vin6 = xnn_loadu_f32(&i3[24]); + vin7 = xnn_loadu_f32(&i3[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin4 = xnn_loadu_f32(&i4[16]); + vin5 = xnn_loadu_f32(&i4[20]); + vin6 = xnn_loadu_f32(&i4[24]); + vin7 = xnn_loadu_f32(&i4[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin4 = xnn_loadu_f32(&i5[16]); + vin5 = xnn_loadu_f32(&i5[20]); + vin6 = xnn_loadu_f32(&i5[24]); + vin7 = xnn_loadu_f32(&i5[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); 
+ vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin4 = xnn_loadu_f32(&i6[16]); + vin5 = xnn_loadu_f32(&i6[20]); + vin6 = xnn_loadu_f32(&i6[24]); + vin7 = xnn_loadu_f32(&i6[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + xnn_storeu_f32(output, vacc4); output += 4; + xnn_storeu_f32(output, vacc5); output += 4; + xnn_storeu_f32(output, vacc6); output += 4; + xnn_storeu_f32(output, vacc7); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[8]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = 
xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], 
vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__neon_u64( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + xnn_simd_f32_t vacc8 = xnn_zero_f32(); + xnn_simd_f32_t vacc9 = xnn_zero_f32(); + xnn_simd_f32_t vacc10 = xnn_zero_f32(); + xnn_simd_f32_t vacc11 = xnn_zero_f32(); + xnn_simd_f32_t vacc12 = xnn_zero_f32(); + xnn_simd_f32_t vacc13 = xnn_zero_f32(); + xnn_simd_f32_t vacc14 = xnn_zero_f32(); + xnn_simd_f32_t vacc15 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + xnn_simd_f32_t vin8; + xnn_simd_f32_t vin9; + xnn_simd_f32_t vin10; + xnn_simd_f32_t vin11; + xnn_simd_f32_t vin12; + xnn_simd_f32_t vin13; + xnn_simd_f32_t vin14; + xnn_simd_f32_t vin15; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin4 = xnn_loadu_f32(&i0[16]); + vin5 = xnn_loadu_f32(&i0[20]); + vin6 = xnn_loadu_f32(&i0[24]); + vin7 = xnn_loadu_f32(&i0[28]); + vin8 = xnn_loadu_f32(&i0[32]); + vin9 = xnn_loadu_f32(&i0[36]); + vin10 = xnn_loadu_f32(&i0[40]); + vin11 = xnn_loadu_f32(&i0[44]); + vin12 
= xnn_loadu_f32(&i0[48]); + vin13 = xnn_loadu_f32(&i0[52]); + vin14 = xnn_loadu_f32(&i0[56]); + vin15 = xnn_loadu_f32(&i0[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin4 = xnn_loadu_f32(&i1[16]); + vin5 = xnn_loadu_f32(&i1[20]); + vin6 = xnn_loadu_f32(&i1[24]); + vin7 = xnn_loadu_f32(&i1[28]); + vin8 = xnn_loadu_f32(&i1[32]); + vin9 = xnn_loadu_f32(&i1[36]); + vin10 = xnn_loadu_f32(&i1[40]); + vin11 = xnn_loadu_f32(&i1[44]); + vin12 = xnn_loadu_f32(&i1[48]); + vin13 = xnn_loadu_f32(&i1[52]); + vin14 = xnn_loadu_f32(&i1[56]); + vin15 = xnn_loadu_f32(&i1[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin4 = xnn_loadu_f32(&i2[16]); + vin5 = xnn_loadu_f32(&i2[20]); + vin6 = xnn_loadu_f32(&i2[24]); + vin7 = xnn_loadu_f32(&i2[28]); + vin8 = xnn_loadu_f32(&i2[32]); + vin9 = xnn_loadu_f32(&i2[36]); + vin10 = xnn_loadu_f32(&i2[40]); + vin11 = xnn_loadu_f32(&i2[44]); + vin12 = xnn_loadu_f32(&i2[48]); + vin13 = xnn_loadu_f32(&i2[52]); + vin14 = xnn_loadu_f32(&i2[56]); + vin15 = xnn_loadu_f32(&i2[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = 
xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin4 = xnn_loadu_f32(&i3[16]); + vin5 = xnn_loadu_f32(&i3[20]); + vin6 = xnn_loadu_f32(&i3[24]); + vin7 = xnn_loadu_f32(&i3[28]); + vin8 = xnn_loadu_f32(&i3[32]); + vin9 = xnn_loadu_f32(&i3[36]); + vin10 = xnn_loadu_f32(&i3[40]); + vin11 = xnn_loadu_f32(&i3[44]); + vin12 = xnn_loadu_f32(&i3[48]); + vin13 = xnn_loadu_f32(&i3[52]); + vin14 = xnn_loadu_f32(&i3[56]); + vin15 = xnn_loadu_f32(&i3[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin4 = xnn_loadu_f32(&i4[16]); + vin5 = xnn_loadu_f32(&i4[20]); + vin6 = xnn_loadu_f32(&i4[24]); + vin7 = xnn_loadu_f32(&i4[28]); + vin8 = xnn_loadu_f32(&i4[32]); + vin9 = xnn_loadu_f32(&i4[36]); + vin10 = xnn_loadu_f32(&i4[40]); + vin11 = xnn_loadu_f32(&i4[44]); + vin12 = xnn_loadu_f32(&i4[48]); + vin13 = xnn_loadu_f32(&i4[52]); + vin14 = xnn_loadu_f32(&i4[56]); + vin15 = xnn_loadu_f32(&i4[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = 
xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin4 = xnn_loadu_f32(&i5[16]); + vin5 = xnn_loadu_f32(&i5[20]); + vin6 = xnn_loadu_f32(&i5[24]); + vin7 = xnn_loadu_f32(&i5[28]); + vin8 = xnn_loadu_f32(&i5[32]); + vin9 = xnn_loadu_f32(&i5[36]); + vin10 = xnn_loadu_f32(&i5[40]); + vin11 = xnn_loadu_f32(&i5[44]); + vin12 = xnn_loadu_f32(&i5[48]); + vin13 = xnn_loadu_f32(&i5[52]); + vin14 = xnn_loadu_f32(&i5[56]); + vin15 = xnn_loadu_f32(&i5[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin4 = xnn_loadu_f32(&i6[16]); + vin5 = xnn_loadu_f32(&i6[20]); + vin6 = xnn_loadu_f32(&i6[24]); + vin7 = xnn_loadu_f32(&i6[28]); + vin8 = xnn_loadu_f32(&i6[32]); + vin9 = xnn_loadu_f32(&i6[36]); + vin10 = xnn_loadu_f32(&i6[40]); + vin11 = xnn_loadu_f32(&i6[44]); + vin12 = xnn_loadu_f32(&i6[48]); + vin13 = xnn_loadu_f32(&i6[52]); + vin14 = xnn_loadu_f32(&i6[56]); + vin15 = xnn_loadu_f32(&i6[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = 
xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + vacc8 = xnn_mul_f32(vacc8, vscale); + vacc9 = xnn_mul_f32(vacc9, vscale); + vacc10 = xnn_mul_f32(vacc10, vscale); + vacc11 = xnn_mul_f32(vacc11, vscale); + vacc12 = xnn_mul_f32(vacc12, vscale); + vacc13 = xnn_mul_f32(vacc13, vscale); + vacc14 = xnn_mul_f32(vacc14, vscale); + vacc15 = xnn_mul_f32(vacc15, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo8 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo9 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo10 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo11 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo12 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo13 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo14 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo15 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + vacc8 = xnn_add_f32(vo8, vacc8); + vacc9 = xnn_add_f32(vo9, vacc9); + vacc10 = xnn_add_f32(vo10, vacc10); + vacc11 = xnn_add_f32(vo11, vacc11); + vacc12 = xnn_add_f32(vo12, vacc12); + vacc13 = xnn_add_f32(vo13, vacc13); + vacc14 = xnn_add_f32(vo14, vacc14); + vacc15 = xnn_add_f32(vo15, vacc15); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + xnn_storeu_f32(output, vacc4); output += 4; + xnn_storeu_f32(output, vacc5); output += 4; + xnn_storeu_f32(output, vacc6); output += 4; + xnn_storeu_f32(output, vacc7); 
output += 4; + xnn_storeu_f32(output, vacc8); output += 4; + xnn_storeu_f32(output, vacc9); output += 4; + xnn_storeu_f32(output, vacc10); output += 4; + xnn_storeu_f32(output, vacc11); output += 4; + xnn_storeu_f32(output, vacc12); output += 4; + xnn_storeu_f32(output, vacc13); output += 4; + xnn_storeu_f32(output, vacc14); output += 4; + xnn_storeu_f32(output, vacc15); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[16]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + vacc[8] = xnn_zero_f32(); + vacc[9] = xnn_zero_f32(); + vacc[10] = xnn_zero_f32(); + vacc[11] = xnn_zero_f32(); + vacc[12] = xnn_zero_f32(); + vacc[13] = xnn_zero_f32(); + vacc[14] = xnn_zero_f32(); + vacc[15] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = 
xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[16]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-scalar.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-scalar.c new file mode 100644 index 00000000000..f4339280cf3 --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-scalar.c @@ -0,0 +1,326 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
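+//
+// Note on semantics (summarized from the kernel body below): for each
+// channel `c` of one (k2, k3) slice, the `rdsum2` ukernel accumulates a
+// scaled sum of squares over up to `k1` rows,
+//
+//   output[c] += params->scalar.scale * sum_{r=0}^{k1-1} in(r, c) * in(r, c)
+//
+// where in(r, c) = *((const float*) ((uintptr_t) input_row + r * input_stride1) + c).
+// Rows are consumed seven per pass (the "7p7x" in the name); in the final
+// partial pass the exhausted row pointers are redirected to the `zero`
+// buffer so they contribute nothing to the sum.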
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-scalar.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__scalar_u4( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 4; channels -= 4) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[1]); + vin2 = xnn_loadu_f32(&i0[2]); + vin3 = xnn_loadu_f32(&i0[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[1]); + vin2 = xnn_loadu_f32(&i1[2]); + vin3 = xnn_loadu_f32(&i1[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[1]); + vin2 = xnn_loadu_f32(&i2[2]); + vin3 = xnn_loadu_f32(&i2[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[1]); + vin2 = xnn_loadu_f32(&i3[2]); + vin3 = 
xnn_loadu_f32(&i3[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[1]); + vin2 = xnn_loadu_f32(&i4[2]); + vin3 = xnn_loadu_f32(&i4[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[1]); + vin2 = xnn_loadu_f32(&i5[2]); + vin3 = xnn_loadu_f32(&i5[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[1]); + vin2 = xnn_loadu_f32(&i6[2]); + vin3 = xnn_loadu_f32(&i6[3]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 1; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 1; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 1; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 1; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 1; + xnn_storeu_f32(output, vacc1); output += 1; + xnn_storeu_f32(output, vacc2); output += 1; + xnn_storeu_f32(output, vacc3); output += 1; + + input_row = (const float*) ((uintptr_t) input_row + 4 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 0; + const size_t num_chunks = round_up_po2(channels, 1) >> 0; + 
const size_t remainder = channels & 0; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*1]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*1]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*1]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*1]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*1]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*1]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*1]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*1], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*1], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*1], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*1], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*1], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*1], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*1], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 0; ++i) { + vo[i] = xnn_loadu_f32(o); o += 1; + } + for (int i = 0; i < channels >> 0; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 0; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 1; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, 
remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-sse2.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-sse2.c new file mode 100644 index 00000000000..66a84992a43 --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-sse2.c @@ -0,0 +1,1386 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-sse2.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__sse2_u16( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = 
xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const 
float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + 
input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__sse2_u32( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin4 = xnn_loadu_f32(&i0[16]); + vin5 = xnn_loadu_f32(&i0[20]); + vin6 = xnn_loadu_f32(&i0[24]); + vin7 = xnn_loadu_f32(&i0[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = 
xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin4 = xnn_loadu_f32(&i1[16]); + vin5 = xnn_loadu_f32(&i1[20]); + vin6 = xnn_loadu_f32(&i1[24]); + vin7 = xnn_loadu_f32(&i1[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin4 = xnn_loadu_f32(&i2[16]); + vin5 = xnn_loadu_f32(&i2[20]); + vin6 = xnn_loadu_f32(&i2[24]); + vin7 = xnn_loadu_f32(&i2[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin4 = xnn_loadu_f32(&i3[16]); + vin5 = xnn_loadu_f32(&i3[20]); + vin6 = xnn_loadu_f32(&i3[24]); + vin7 = xnn_loadu_f32(&i3[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin4 = xnn_loadu_f32(&i4[16]); + vin5 = xnn_loadu_f32(&i4[20]); + vin6 = xnn_loadu_f32(&i4[24]); + vin7 = xnn_loadu_f32(&i4[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, 
vacc7); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin4 = xnn_loadu_f32(&i5[16]); + vin5 = xnn_loadu_f32(&i5[20]); + vin6 = xnn_loadu_f32(&i5[24]); + vin7 = xnn_loadu_f32(&i5[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin4 = xnn_loadu_f32(&i6[16]); + vin5 = xnn_loadu_f32(&i6[20]); + vin6 = xnn_loadu_f32(&i6[24]); + vin7 = xnn_loadu_f32(&i6[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + xnn_storeu_f32(output, vacc4); output += 4; + xnn_storeu_f32(output, vacc5); output += 4; + xnn_storeu_f32(output, vacc6); output += 4; + xnn_storeu_f32(output, vacc7); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment 
= 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[8]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) 
((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__sse2_u64( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + xnn_simd_f32_t vacc8 = xnn_zero_f32(); + xnn_simd_f32_t vacc9 = xnn_zero_f32(); + xnn_simd_f32_t vacc10 = xnn_zero_f32(); + xnn_simd_f32_t vacc11 = xnn_zero_f32(); + xnn_simd_f32_t vacc12 = xnn_zero_f32(); + xnn_simd_f32_t vacc13 = xnn_zero_f32(); + xnn_simd_f32_t vacc14 = xnn_zero_f32(); + xnn_simd_f32_t vacc15 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + 
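+ // vin0..vin15 hold one 64-float slice of a row; each vector is squared and
+ // added into the matching accumulator vacc0..vacc15 before moving to the
+ // next of the seven rows.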
xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + xnn_simd_f32_t vin8; + xnn_simd_f32_t vin9; + xnn_simd_f32_t vin10; + xnn_simd_f32_t vin11; + xnn_simd_f32_t vin12; + xnn_simd_f32_t vin13; + xnn_simd_f32_t vin14; + xnn_simd_f32_t vin15; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin4 = xnn_loadu_f32(&i0[16]); + vin5 = xnn_loadu_f32(&i0[20]); + vin6 = xnn_loadu_f32(&i0[24]); + vin7 = xnn_loadu_f32(&i0[28]); + vin8 = xnn_loadu_f32(&i0[32]); + vin9 = xnn_loadu_f32(&i0[36]); + vin10 = xnn_loadu_f32(&i0[40]); + vin11 = xnn_loadu_f32(&i0[44]); + vin12 = xnn_loadu_f32(&i0[48]); + vin13 = xnn_loadu_f32(&i0[52]); + vin14 = xnn_loadu_f32(&i0[56]); + vin15 = xnn_loadu_f32(&i0[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin4 = xnn_loadu_f32(&i1[16]); + vin5 = xnn_loadu_f32(&i1[20]); + vin6 = xnn_loadu_f32(&i1[24]); + vin7 = xnn_loadu_f32(&i1[28]); + vin8 = xnn_loadu_f32(&i1[32]); + vin9 = xnn_loadu_f32(&i1[36]); + vin10 = xnn_loadu_f32(&i1[40]); + vin11 = xnn_loadu_f32(&i1[44]); + vin12 = xnn_loadu_f32(&i1[48]); + vin13 = xnn_loadu_f32(&i1[52]); + vin14 = xnn_loadu_f32(&i1[56]); + vin15 = xnn_loadu_f32(&i1[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = 
xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin4 = xnn_loadu_f32(&i2[16]); + vin5 = xnn_loadu_f32(&i2[20]); + vin6 = xnn_loadu_f32(&i2[24]); + vin7 = xnn_loadu_f32(&i2[28]); + vin8 = xnn_loadu_f32(&i2[32]); + vin9 = xnn_loadu_f32(&i2[36]); + vin10 = xnn_loadu_f32(&i2[40]); + vin11 = xnn_loadu_f32(&i2[44]); + vin12 = xnn_loadu_f32(&i2[48]); + vin13 = xnn_loadu_f32(&i2[52]); + vin14 = xnn_loadu_f32(&i2[56]); + vin15 = xnn_loadu_f32(&i2[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin4 = xnn_loadu_f32(&i3[16]); + vin5 = xnn_loadu_f32(&i3[20]); + vin6 = xnn_loadu_f32(&i3[24]); + vin7 = xnn_loadu_f32(&i3[28]); + vin8 = xnn_loadu_f32(&i3[32]); + vin9 = xnn_loadu_f32(&i3[36]); + vin10 = xnn_loadu_f32(&i3[40]); + vin11 = xnn_loadu_f32(&i3[44]); + vin12 = xnn_loadu_f32(&i3[48]); + vin13 = xnn_loadu_f32(&i3[52]); + vin14 = xnn_loadu_f32(&i3[56]); + vin15 = xnn_loadu_f32(&i3[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = 
xnn_loadu_f32(&i4[12]); + vin4 = xnn_loadu_f32(&i4[16]); + vin5 = xnn_loadu_f32(&i4[20]); + vin6 = xnn_loadu_f32(&i4[24]); + vin7 = xnn_loadu_f32(&i4[28]); + vin8 = xnn_loadu_f32(&i4[32]); + vin9 = xnn_loadu_f32(&i4[36]); + vin10 = xnn_loadu_f32(&i4[40]); + vin11 = xnn_loadu_f32(&i4[44]); + vin12 = xnn_loadu_f32(&i4[48]); + vin13 = xnn_loadu_f32(&i4[52]); + vin14 = xnn_loadu_f32(&i4[56]); + vin15 = xnn_loadu_f32(&i4[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin4 = xnn_loadu_f32(&i5[16]); + vin5 = xnn_loadu_f32(&i5[20]); + vin6 = xnn_loadu_f32(&i5[24]); + vin7 = xnn_loadu_f32(&i5[28]); + vin8 = xnn_loadu_f32(&i5[32]); + vin9 = xnn_loadu_f32(&i5[36]); + vin10 = xnn_loadu_f32(&i5[40]); + vin11 = xnn_loadu_f32(&i5[44]); + vin12 = xnn_loadu_f32(&i5[48]); + vin13 = xnn_loadu_f32(&i5[52]); + vin14 = xnn_loadu_f32(&i5[56]); + vin15 = xnn_loadu_f32(&i5[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin4 = xnn_loadu_f32(&i6[16]); + vin5 = xnn_loadu_f32(&i6[20]); + vin6 = xnn_loadu_f32(&i6[24]); + vin7 = xnn_loadu_f32(&i6[28]); + vin8 = 
xnn_loadu_f32(&i6[32]); + vin9 = xnn_loadu_f32(&i6[36]); + vin10 = xnn_loadu_f32(&i6[40]); + vin11 = xnn_loadu_f32(&i6[44]); + vin12 = xnn_loadu_f32(&i6[48]); + vin13 = xnn_loadu_f32(&i6[52]); + vin14 = xnn_loadu_f32(&i6[56]); + vin15 = xnn_loadu_f32(&i6[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + vacc8 = xnn_mul_f32(vacc8, vscale); + vacc9 = xnn_mul_f32(vacc9, vscale); + vacc10 = xnn_mul_f32(vacc10, vscale); + vacc11 = xnn_mul_f32(vacc11, vscale); + vacc12 = xnn_mul_f32(vacc12, vscale); + vacc13 = xnn_mul_f32(vacc13, vscale); + vacc14 = xnn_mul_f32(vacc14, vscale); + vacc15 = xnn_mul_f32(vacc15, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo8 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo9 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo10 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo11 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo12 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo13 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo14 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo15 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + vacc8 
= xnn_add_f32(vo8, vacc8); + vacc9 = xnn_add_f32(vo9, vacc9); + vacc10 = xnn_add_f32(vo10, vacc10); + vacc11 = xnn_add_f32(vo11, vacc11); + vacc12 = xnn_add_f32(vo12, vacc12); + vacc13 = xnn_add_f32(vo13, vacc13); + vacc14 = xnn_add_f32(vo14, vacc14); + vacc15 = xnn_add_f32(vo15, vacc15); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + xnn_storeu_f32(output, vacc4); output += 4; + xnn_storeu_f32(output, vacc5); output += 4; + xnn_storeu_f32(output, vacc6); output += 4; + xnn_storeu_f32(output, vacc7); output += 4; + xnn_storeu_f32(output, vacc8); output += 4; + xnn_storeu_f32(output, vacc9); output += 4; + xnn_storeu_f32(output, vacc10); output += 4; + xnn_storeu_f32(output, vacc11); output += 4; + xnn_storeu_f32(output, vacc12); output += 4; + xnn_storeu_f32(output, vacc13); output += 4; + xnn_storeu_f32(output, vacc14); output += 4; + xnn_storeu_f32(output, vacc15); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[16]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + vacc[8] = xnn_zero_f32(); + vacc[9] = xnn_zero_f32(); + vacc[10] = xnn_zero_f32(); + vacc[11] = xnn_zero_f32(); + vacc[12] = xnn_zero_f32(); + vacc[13] = xnn_zero_f32(); + vacc[14] = xnn_zero_f32(); + vacc[15] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = 
xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[16]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-wasmsimd.c b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-wasmsimd.c new file mode 100644 index 00000000000..29c18577f0b --- /dev/null +++ b/src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-wasmsimd.c @@ -0,0 +1,1386 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rdsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
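+//
+// The rdsum2 ukernels below accumulate a scaled sum of squares into `output`:
+// for each of `channels` contiguous elements, the squares of the inputs are
+// reduced over the k1 (stride `input_stride1`), k2 (`input_stride2`), and k3
+// (`input_stride3`) dimensions and scaled by `params->scalar.scale`. Rows
+// redirected to `zero` when k1 is not a multiple of 7 contribute nothing.
+// Roughly equivalent scalar sketch (illustrative only; summation order, and
+// hence rounding, differs from the vectorized code):
+//
+//   for (size_t c = 0; c < channels; c++) {
+//     float acc = 0.0f;
+//     for (size_t i3 = 0; i3 < k3; i3++) {
+//       for (size_t i2 = 0; i2 < k2; i2++) {
+//         for (size_t i1 = 0; i1 < k1; i1++) {
+//           const float* row = (const float*) ((uintptr_t) input +
+//               i1 * input_stride1 + i2 * input_stride2 + i3 * input_stride3);
+//           acc += row[c] * row[c];
+//         }
+//       }
+//     }
+//     output[c] += acc * params->scalar.scale;
+//   }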
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-wasmsimd.h" + + +void xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u16( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 16; channels -= 16) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = 
xnn_loadu_f32(&i3[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[4]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 
2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, 
remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u32( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 32; channels -= 32) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin4 = xnn_loadu_f32(&i0[16]); + vin5 = xnn_loadu_f32(&i0[20]); + vin6 = xnn_loadu_f32(&i0[24]); + vin7 = xnn_loadu_f32(&i0[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin4 = xnn_loadu_f32(&i1[16]); + vin5 = xnn_loadu_f32(&i1[20]); + vin6 = xnn_loadu_f32(&i1[24]); + vin7 = xnn_loadu_f32(&i1[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + 
vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin4 = xnn_loadu_f32(&i2[16]); + vin5 = xnn_loadu_f32(&i2[20]); + vin6 = xnn_loadu_f32(&i2[24]); + vin7 = xnn_loadu_f32(&i2[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin4 = xnn_loadu_f32(&i3[16]); + vin5 = xnn_loadu_f32(&i3[20]); + vin6 = xnn_loadu_f32(&i3[24]); + vin7 = xnn_loadu_f32(&i3[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin4 = xnn_loadu_f32(&i4[16]); + vin5 = xnn_loadu_f32(&i4[20]); + vin6 = xnn_loadu_f32(&i4[24]); + vin7 = xnn_loadu_f32(&i4[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin4 = xnn_loadu_f32(&i5[16]); + vin5 = xnn_loadu_f32(&i5[20]); + vin6 = xnn_loadu_f32(&i5[24]); + vin7 = xnn_loadu_f32(&i5[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, 
vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin4 = xnn_loadu_f32(&i6[16]); + vin5 = xnn_loadu_f32(&i6[20]); + vin6 = xnn_loadu_f32(&i6[24]); + vin7 = xnn_loadu_f32(&i6[28]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + xnn_storeu_f32(output, vacc4); output += 4; + xnn_storeu_f32(output, vacc5); output += 4; + xnn_storeu_f32(output, vacc6); output += 4; + xnn_storeu_f32(output, vacc7); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 32 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[8]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = 
xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[8]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], 
vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} + +void xnn_f32_rdsum2_ukernel_7p7x__wasmsimd_u64( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = 7 * input_stride1; + for (; channels >= 64; channels -= 64) { + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + xnn_simd_f32_t vacc4 = xnn_zero_f32(); + xnn_simd_f32_t vacc5 = xnn_zero_f32(); + xnn_simd_f32_t vacc6 = xnn_zero_f32(); + xnn_simd_f32_t vacc7 = xnn_zero_f32(); + xnn_simd_f32_t vacc8 = xnn_zero_f32(); + xnn_simd_f32_t vacc9 = xnn_zero_f32(); + xnn_simd_f32_t vacc10 = xnn_zero_f32(); + xnn_simd_f32_t vacc11 = xnn_zero_f32(); + xnn_simd_f32_t vacc12 = xnn_zero_f32(); + xnn_simd_f32_t vacc13 = xnn_zero_f32(); + xnn_simd_f32_t vacc14 = xnn_zero_f32(); + xnn_simd_f32_t vacc15 = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + xnn_simd_f32_t vin0; + xnn_simd_f32_t vin1; + xnn_simd_f32_t vin2; + xnn_simd_f32_t vin3; + xnn_simd_f32_t vin4; + xnn_simd_f32_t vin5; + xnn_simd_f32_t vin6; + xnn_simd_f32_t vin7; + xnn_simd_f32_t vin8; + xnn_simd_f32_t vin9; + xnn_simd_f32_t vin10; + xnn_simd_f32_t vin11; + xnn_simd_f32_t vin12; + xnn_simd_f32_t vin13; + xnn_simd_f32_t vin14; + xnn_simd_f32_t vin15; + vin0 = xnn_loadu_f32(&i0[0]); + vin1 = xnn_loadu_f32(&i0[4]); + vin2 = xnn_loadu_f32(&i0[8]); + vin3 = xnn_loadu_f32(&i0[12]); + vin4 = xnn_loadu_f32(&i0[16]); + vin5 = xnn_loadu_f32(&i0[20]); + vin6 = xnn_loadu_f32(&i0[24]); + vin7 = xnn_loadu_f32(&i0[28]); + vin8 = xnn_loadu_f32(&i0[32]); + vin9 = xnn_loadu_f32(&i0[36]); + vin10 = xnn_loadu_f32(&i0[40]); + vin11 = xnn_loadu_f32(&i0[44]); + 
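+ // 16 vectors of 4 floats cover the 64-channel block for row i0; the same
+ // load/square/accumulate pattern repeats for rows i1..i6 below.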
vin12 = xnn_loadu_f32(&i0[48]); + vin13 = xnn_loadu_f32(&i0[52]); + vin14 = xnn_loadu_f32(&i0[56]); + vin15 = xnn_loadu_f32(&i0[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i1[0]); + vin1 = xnn_loadu_f32(&i1[4]); + vin2 = xnn_loadu_f32(&i1[8]); + vin3 = xnn_loadu_f32(&i1[12]); + vin4 = xnn_loadu_f32(&i1[16]); + vin5 = xnn_loadu_f32(&i1[20]); + vin6 = xnn_loadu_f32(&i1[24]); + vin7 = xnn_loadu_f32(&i1[28]); + vin8 = xnn_loadu_f32(&i1[32]); + vin9 = xnn_loadu_f32(&i1[36]); + vin10 = xnn_loadu_f32(&i1[40]); + vin11 = xnn_loadu_f32(&i1[44]); + vin12 = xnn_loadu_f32(&i1[48]); + vin13 = xnn_loadu_f32(&i1[52]); + vin14 = xnn_loadu_f32(&i1[56]); + vin15 = xnn_loadu_f32(&i1[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i2[0]); + vin1 = xnn_loadu_f32(&i2[4]); + vin2 = xnn_loadu_f32(&i2[8]); + vin3 = xnn_loadu_f32(&i2[12]); + vin4 = xnn_loadu_f32(&i2[16]); + vin5 = xnn_loadu_f32(&i2[20]); + vin6 = xnn_loadu_f32(&i2[24]); + vin7 = xnn_loadu_f32(&i2[28]); + vin8 = xnn_loadu_f32(&i2[32]); + vin9 = xnn_loadu_f32(&i2[36]); + vin10 = xnn_loadu_f32(&i2[40]); + vin11 = xnn_loadu_f32(&i2[44]); + vin12 = xnn_loadu_f32(&i2[48]); + vin13 = xnn_loadu_f32(&i2[52]); + vin14 = xnn_loadu_f32(&i2[56]); + vin15 = xnn_loadu_f32(&i2[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = 
xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i3[0]); + vin1 = xnn_loadu_f32(&i3[4]); + vin2 = xnn_loadu_f32(&i3[8]); + vin3 = xnn_loadu_f32(&i3[12]); + vin4 = xnn_loadu_f32(&i3[16]); + vin5 = xnn_loadu_f32(&i3[20]); + vin6 = xnn_loadu_f32(&i3[24]); + vin7 = xnn_loadu_f32(&i3[28]); + vin8 = xnn_loadu_f32(&i3[32]); + vin9 = xnn_loadu_f32(&i3[36]); + vin10 = xnn_loadu_f32(&i3[40]); + vin11 = xnn_loadu_f32(&i3[44]); + vin12 = xnn_loadu_f32(&i3[48]); + vin13 = xnn_loadu_f32(&i3[52]); + vin14 = xnn_loadu_f32(&i3[56]); + vin15 = xnn_loadu_f32(&i3[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i4[0]); + vin1 = xnn_loadu_f32(&i4[4]); + vin2 = xnn_loadu_f32(&i4[8]); + vin3 = xnn_loadu_f32(&i4[12]); + vin4 = xnn_loadu_f32(&i4[16]); + vin5 = xnn_loadu_f32(&i4[20]); + vin6 = xnn_loadu_f32(&i4[24]); + vin7 = xnn_loadu_f32(&i4[28]); + vin8 = xnn_loadu_f32(&i4[32]); + vin9 = xnn_loadu_f32(&i4[36]); + vin10 = xnn_loadu_f32(&i4[40]); + vin11 = xnn_loadu_f32(&i4[44]); + vin12 = xnn_loadu_f32(&i4[48]); + vin13 = xnn_loadu_f32(&i4[52]); + vin14 = xnn_loadu_f32(&i4[56]); + vin15 = xnn_loadu_f32(&i4[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = 
xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i5[0]); + vin1 = xnn_loadu_f32(&i5[4]); + vin2 = xnn_loadu_f32(&i5[8]); + vin3 = xnn_loadu_f32(&i5[12]); + vin4 = xnn_loadu_f32(&i5[16]); + vin5 = xnn_loadu_f32(&i5[20]); + vin6 = xnn_loadu_f32(&i5[24]); + vin7 = xnn_loadu_f32(&i5[28]); + vin8 = xnn_loadu_f32(&i5[32]); + vin9 = xnn_loadu_f32(&i5[36]); + vin10 = xnn_loadu_f32(&i5[40]); + vin11 = xnn_loadu_f32(&i5[44]); + vin12 = xnn_loadu_f32(&i5[48]); + vin13 = xnn_loadu_f32(&i5[52]); + vin14 = xnn_loadu_f32(&i5[56]); + vin15 = xnn_loadu_f32(&i5[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + vin0 = xnn_loadu_f32(&i6[0]); + vin1 = xnn_loadu_f32(&i6[4]); + vin2 = xnn_loadu_f32(&i6[8]); + vin3 = xnn_loadu_f32(&i6[12]); + vin4 = xnn_loadu_f32(&i6[16]); + vin5 = xnn_loadu_f32(&i6[20]); + vin6 = xnn_loadu_f32(&i6[24]); + vin7 = xnn_loadu_f32(&i6[28]); + vin8 = xnn_loadu_f32(&i6[32]); + vin9 = xnn_loadu_f32(&i6[36]); + vin10 = xnn_loadu_f32(&i6[40]); + vin11 = xnn_loadu_f32(&i6[44]); + vin12 = xnn_loadu_f32(&i6[48]); + vin13 = xnn_loadu_f32(&i6[52]); + vin14 = xnn_loadu_f32(&i6[56]); + vin15 = xnn_loadu_f32(&i6[60]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vin7 = xnn_mul_f32(vin7, vin7); + vin8 = xnn_mul_f32(vin8, vin8); + vin9 = xnn_mul_f32(vin9, vin9); + vin10 = xnn_mul_f32(vin10, vin10); + vin11 = 
xnn_mul_f32(vin11, vin11); + vin12 = xnn_mul_f32(vin12, vin12); + vin13 = xnn_mul_f32(vin13, vin13); + vin14 = xnn_mul_f32(vin14, vin14); + vin15 = xnn_mul_f32(vin15, vin15); + vacc0 = xnn_add_f32(vin0, vacc0); + vacc1 = xnn_add_f32(vin1, vacc1); + vacc2 = xnn_add_f32(vin2, vacc2); + vacc3 = xnn_add_f32(vin3, vacc3); + vacc4 = xnn_add_f32(vin4, vacc4); + vacc5 = xnn_add_f32(vin5, vacc5); + vacc6 = xnn_add_f32(vin6, vacc6); + vacc7 = xnn_add_f32(vin7, vacc7); + vacc8 = xnn_add_f32(vin8, vacc8); + vacc9 = xnn_add_f32(vin9, vacc9); + vacc10 = xnn_add_f32(vin10, vacc10); + vacc11 = xnn_add_f32(vin11, vacc11); + vacc12 = xnn_add_f32(vin12, vacc12); + vacc13 = xnn_add_f32(vin13, vacc13); + vacc14 = xnn_add_f32(vin14, vacc14); + vacc15 = xnn_add_f32(vin15, vacc15); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = xnn_mul_f32(vacc0, vscale); + vacc1 = xnn_mul_f32(vacc1, vscale); + vacc2 = xnn_mul_f32(vacc2, vscale); + vacc3 = xnn_mul_f32(vacc3, vscale); + vacc4 = xnn_mul_f32(vacc4, vscale); + vacc5 = xnn_mul_f32(vacc5, vscale); + vacc6 = xnn_mul_f32(vacc6, vscale); + vacc7 = xnn_mul_f32(vacc7, vscale); + vacc8 = xnn_mul_f32(vacc8, vscale); + vacc9 = xnn_mul_f32(vacc9, vscale); + vacc10 = xnn_mul_f32(vacc10, vscale); + vacc11 = xnn_mul_f32(vacc11, vscale); + vacc12 = xnn_mul_f32(vacc12, vscale); + vacc13 = xnn_mul_f32(vacc13, vscale); + vacc14 = xnn_mul_f32(vacc14, vscale); + vacc15 = xnn_mul_f32(vacc15, vscale); + + const float* o = output; + xnn_simd_f32_t vo0 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo1 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo2 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo3 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo4 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo5 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo6 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo7 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo8 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo9 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo10 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo11 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo12 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo13 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo14 = xnn_loadu_f32(o); o += 4; + xnn_simd_f32_t vo15 = xnn_loadu_f32(o); o += 4; + vacc0 = xnn_add_f32(vo0, vacc0); + vacc1 = xnn_add_f32(vo1, vacc1); + vacc2 = xnn_add_f32(vo2, vacc2); + vacc3 = xnn_add_f32(vo3, vacc3); + vacc4 = xnn_add_f32(vo4, vacc4); + vacc5 = xnn_add_f32(vo5, vacc5); + vacc6 = xnn_add_f32(vo6, vacc6); + vacc7 = xnn_add_f32(vo7, vacc7); + vacc8 = xnn_add_f32(vo8, vacc8); + vacc9 = xnn_add_f32(vo9, vacc9); + vacc10 = xnn_add_f32(vo10, vacc10); + vacc11 = xnn_add_f32(vo11, vacc11); + vacc12 = xnn_add_f32(vo12, vacc12); + vacc13 = xnn_add_f32(vo13, vacc13); + vacc14 = xnn_add_f32(vo14, vacc14); + vacc15 = xnn_add_f32(vo15, vacc15); + xnn_storeu_f32(output, vacc0); output += 4; + xnn_storeu_f32(output, vacc1); output += 4; + xnn_storeu_f32(output, vacc2); output += 4; + xnn_storeu_f32(output, vacc3); output += 4; + xnn_storeu_f32(output, vacc4); output += 4; + xnn_storeu_f32(output, vacc5); output += 4; + xnn_storeu_f32(output, vacc6); output += 4; + xnn_storeu_f32(output, vacc7); 
output += 4; + xnn_storeu_f32(output, vacc8); output += 4; + xnn_storeu_f32(output, vacc9); output += 4; + xnn_storeu_f32(output, vacc10); output += 4; + xnn_storeu_f32(output, vacc11); output += 4; + xnn_storeu_f32(output, vacc12); output += 4; + xnn_storeu_f32(output, vacc13); output += 4; + xnn_storeu_f32(output, vacc14); output += 4; + xnn_storeu_f32(output, vacc15); output += 4; + + input_row = (const float*) ((uintptr_t) input_row + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride1; + const float* i0 = input_row; + const float* i1 = + (const float*) ((uintptr_t) input_row + 1 * input_stride1); + const float* i2 = + (const float*) ((uintptr_t) input_row + 2 * input_stride1); + const float* i3 = + (const float*) ((uintptr_t) input_row + 3 * input_stride1); + const float* i4 = + (const float*) ((uintptr_t) input_row + 4 * input_stride1); + const float* i5 = + (const float*) ((uintptr_t) input_row + 5 * input_stride1); + const float* i6 = + (const float*) ((uintptr_t) input_row + 6 * input_stride1); + xnn_simd_f32_t vacc[16]; + vacc[0] = xnn_zero_f32(); + vacc[1] = xnn_zero_f32(); + vacc[2] = xnn_zero_f32(); + vacc[3] = xnn_zero_f32(); + vacc[4] = xnn_zero_f32(); + vacc[5] = xnn_zero_f32(); + vacc[6] = xnn_zero_f32(); + vacc[7] = xnn_zero_f32(); + vacc[8] = xnn_zero_f32(); + vacc[9] = xnn_zero_f32(); + vacc[10] = xnn_zero_f32(); + vacc[11] = xnn_zero_f32(); + vacc[12] = xnn_zero_f32(); + vacc[13] = xnn_zero_f32(); + vacc[14] = xnn_zero_f32(); + vacc[15] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> 2; + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t remainder = channels & 3; + for (int r = k1; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + xnn_simd_f32_t vin0 = xnn_loadu_f32(&i0[i*4]); + xnn_simd_f32_t vin1 = xnn_loadu_f32(&i1[i*4]); + xnn_simd_f32_t vin2 = xnn_loadu_f32(&i2[i*4]); + xnn_simd_f32_t vin3 = xnn_loadu_f32(&i3[i*4]); + xnn_simd_f32_t vin4 = xnn_loadu_f32(&i4[i*4]); + xnn_simd_f32_t vin5 = xnn_loadu_f32(&i5[i*4]); + xnn_simd_f32_t vin6 = xnn_loadu_f32(&i6[i*4]); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[i] = xnn_add_f32(vin0, vacc[i]); + vacc[i] = xnn_add_f32(vin1, vacc[i]); + vacc[i] = xnn_add_f32(vin2, vacc[i]); + vacc[i] = xnn_add_f32(vin3, vacc[i]); + vacc[i] = xnn_add_f32(vin4, vacc[i]); + vacc[i] = xnn_add_f32(vin5, vacc[i]); + vacc[i] = xnn_add_f32(vin6, vacc[i]); + } + + if (remainder) { + xnn_simd_f32_t vin0 = xnn_load_tail_f32(&i0[num_full_chunks*4], remainder); + xnn_simd_f32_t vin1 = xnn_load_tail_f32(&i1[num_full_chunks*4], remainder); + xnn_simd_f32_t vin2 = xnn_load_tail_f32(&i2[num_full_chunks*4], remainder); + xnn_simd_f32_t vin3 = xnn_load_tail_f32(&i3[num_full_chunks*4], remainder); + xnn_simd_f32_t vin4 = xnn_load_tail_f32(&i4[num_full_chunks*4], remainder); + xnn_simd_f32_t vin5 = xnn_load_tail_f32(&i5[num_full_chunks*4], remainder); + xnn_simd_f32_t vin6 = xnn_load_tail_f32(&i6[num_full_chunks*4], remainder); + vin0 = xnn_mul_f32(vin0, vin0); + vin1 = 
xnn_mul_f32(vin1, vin1); + vin2 = xnn_mul_f32(vin2, vin2); + vin3 = xnn_mul_f32(vin3, vin3); + vin4 = xnn_mul_f32(vin4, vin4); + vin5 = xnn_mul_f32(vin5, vin5); + vin6 = xnn_mul_f32(vin6, vin6); + vacc[num_full_chunks] = xnn_add_f32(vin0, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin1, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin2, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin3, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin4, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin5, vacc[num_full_chunks]); + vacc[num_full_chunks] = xnn_add_f32(vin6, vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[16]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = xnn_loadu_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + xnn_storeu_f32(output, vacc[i]); output += 4; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } +} diff --git a/src/f32-rdsum2/simd.c.in b/src/f32-rdsum2/simd.c.in new file mode 100644 index 00000000000..8dec308a0c7 --- /dev/null +++ b/src/f32-rdsum2/simd.c.in @@ -0,0 +1,155 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
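
Reviewer note (not part of the patch): the xnn_f32_rdsum2 kernels above and the template that follows all implement the same "discontiguous sum of squares" contract. The sketch below restates it in scalar C for reference only; the function name is illustrative, all strides are in bytes, and the scale is passed directly even though the real kernels read it from params->scalar.scale. The vector kernels apply the scale once per (k2, k3) slice and substitute the `zero` buffer for row pointers past the end of k1, which is algebraically equivalent to this form.

#include <stddef.h>
#include <stdint.h>

static void f32_rdsum2_reference(size_t channels, size_t k1, size_t k2,
                                 size_t k3, const float* input,
                                 size_t input_stride1, size_t input_stride2,
                                 size_t input_stride3, float* output,
                                 float scale) {
  for (size_t c = 0; c < channels; c++) {
    float acc = 0.0f;
    for (size_t i3 = 0; i3 < k3; i3++) {
      for (size_t i2 = 0; i2 < k2; i2++) {
        for (size_t i1 = 0; i1 < k1; i1++) {
          // Strides are byte offsets, matching the microkernel signature.
          const float* row =
              (const float*) ((uintptr_t) input + i1 * input_stride1 +
                              i2 * input_stride2 + i3 * input_stride3);
          acc += row[c] * row[c];
        }
      }
    }
    // Accumulates into the output, as the vector kernels do.
    output[c] += acc * scale;
  }
}
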
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-${ARCH}.h" + +$import math +$LOG2_SIMD_SIZE = int(math.log2(SIMD_SIZE)) +$CHANNELS_LIST = tuple(int(c) for c in str(CHANNELS).split(",")) +$for CHANNELS in $CHANNELS_LIST: + $UNROLL = CHANNELS >> LOG2_SIMD_SIZE + + void xnn_f32_rdsum2_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__${ARCH}_u${CHANNELS}( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(k1 != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const xnn_simd_f32_t vscale = xnn_set1_f32(params->scalar.scale); + float* original_output = output; + size_t original_channels = channels; + + for (size_t k = 0; k < k3; ++k) { + for (size_t j = 0; j < k2; ++j) { + const float* input_row = + (const float*)((uintptr_t)input + j * input_stride2 + + k * input_stride3); + output = original_output; + channels = original_channels; + + assert(input_row != NULL); + + size_t input_increment = ${ACCUMULATORS} * input_stride1; + for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) { + const float* i0 = input_row; + $for i in range(1, ACCUMULATORS): + const float* i${i} = + (const float*) ((uintptr_t) input_row + ${i} * input_stride1); + + $for i in range(UNROLL): + xnn_simd_f32_t vacc${i} = xnn_zero_f32(); + + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + $for c in range(UNROLL): + xnn_simd_f32_t vin${c}; + $for j in range(ACCUMULATORS): + $for c in range(UNROLL): + vin${c} = xnn_loadu_f32(&i${j}[${c*SIMD_SIZE}]); + $for c in range(UNROLL): + vin${c} = xnn_mul_f32(vin${c}, vin${c}); + $for c in range(UNROLL): + vacc${c} = xnn_add_f32(vin${c}, vacc${c}); + $for N in range(0, ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + $for i in range(UNROLL): + vacc${i} = xnn_mul_f32(vacc${i}, vscale); + + const float* o = output; + $for i in range(0, UNROLL): + xnn_simd_f32_t vo${i} = xnn_loadu_f32(o); o += ${SIMD_SIZE}; + $for i in range(0, UNROLL): + vacc${i} = xnn_add_f32(vo${i}, vacc${i}); + $for i in range(0, UNROLL): + xnn_storeu_f32(output, vacc${i}); output += ${SIMD_SIZE}; + + input_row = (const float*) ((uintptr_t) input_row + ${CHANNELS} * sizeof(float)); + } + if (channels != 0) { + input_increment = ${ACCUMULATORS} * input_stride1; + const float* i0 = input_row; + $for i in range(1, ACCUMULATORS): + const float* i${i} = + (const float*) ((uintptr_t) input_row + ${i} * input_stride1); + xnn_simd_f32_t vacc[${UNROLL}]; + $for i in range(UNROLL): + vacc[${i}] = xnn_zero_f32(); + + const size_t num_full_chunks = channels >> ${LOG2_SIMD_SIZE}; + const size_t num_chunks = round_up_po2(channels, ${SIMD_SIZE}) >> ${LOG2_SIMD_SIZE}; + const size_t remainder = channels & ${SIMD_SIZE - 1}; + for (int r = k1; r > 0; r -= ${ACCUMULATORS}) { + $for N in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${N+1}) { + i${N} = zero; + } + if XNN_UNPREDICTABLE(r <= ${N+1}) { + i${N+1} = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + $for c in range(ACCUMULATORS): + xnn_simd_f32_t vin${c} = xnn_loadu_f32(&i${c}[i*${SIMD_SIZE}]); + $for c in 
range(ACCUMULATORS): + vin${c} = xnn_mul_f32(vin${c}, vin${c}); + $for c in range(ACCUMULATORS): + vacc[i] = xnn_add_f32(vin${c}, vacc[i]); + } + + if (remainder) { + $for c in range(ACCUMULATORS): + xnn_simd_f32_t vin${c} = xnn_load_tail_f32(&i${c}[num_full_chunks*${SIMD_SIZE}], remainder); + $for c in range(ACCUMULATORS): + vin${c} = xnn_mul_f32(vin${c}, vin${c}); + $for c in range(ACCUMULATORS): + vacc[num_full_chunks] = xnn_add_f32(vin${c}, vacc[num_full_chunks]); + } + $for N in range(ACCUMULATORS): + i${N} = (const float*) ((uintptr_t) i${N} + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = xnn_mul_f32(vacc[i], vscale); + } + + xnn_simd_f32_t vo[${UNROLL}]; + const float* o = output; + for (int i = 0; i < channels >> ${LOG2_SIMD_SIZE}; ++i) { + vo[i] = xnn_loadu_f32(o); o += ${SIMD_SIZE}; + } + for (int i = 0; i < channels >> ${LOG2_SIMD_SIZE}; ++i) { + vacc[i] = xnn_add_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> ${LOG2_SIMD_SIZE}; ++i) { + xnn_storeu_f32(output, vacc[i]); output += ${SIMD_SIZE}; + } + if (remainder) { + const size_t pos = num_full_chunks; + xnn_simd_f32_t vout = vacc[pos]; + const xnn_simd_f32_t vdata = xnn_load_tail_safe_f32(output, remainder); + vout = xnn_add_f32(vout, vdata); + xnn_store_tail_f32(output, vout, remainder); + } + } + } + } + } diff --git a/src/f32-rsum2/f32-rsum2.inc b/src/f32-rsum2/f32-rsum2.inc new file mode 100644 index 00000000000..61302bfa7df --- /dev/null +++ b/src/f32-rsum2/f32-rsum2.inc @@ -0,0 +1,66 @@ +// clang-format off +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u4, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u8, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u8_acc2, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u12, 12, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u12_acc3, 12, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u16, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u16_acc2, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_arm_neon, xnn_f32_rsum2_ukernel__neon_u16_acc4, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 +XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u4, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u8, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u8_acc2, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params) +XNN_UKERNEL(xnn_arch_none, 
xnn_f32_rsum2_ukernel__sse2_u12, 12, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u12_acc3, 12, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u16, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u16_acc2, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__sse2_u16_acc4, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u8, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u16, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u16_acc2, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u24, 24, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u24_acc3, 24, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u32, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u32_acc2, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_rsum2_ukernel__avx_u32_acc4, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u16, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u32, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u32_acc2, 32, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u48, 48, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u48_acc3, 48, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u64, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u64_acc2, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f32_rsum2_ukernel__avx512f_u64_acc4, 64, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
+
+#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u4, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u8, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u8_acc2, 8, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u12, 12, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u12_acc3, 12, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u16, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u16_acc2, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__wasmsimd_u16_acc4, 16, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
+
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u1, 1, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u2, 2, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u2_acc2, 2, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u3, 3, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u3_acc3, 3, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u4, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u4_acc2, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
+XNN_UKERNEL(xnn_arch_none, xnn_f32_rsum2_ukernel__scalar_u4_acc4, 4, false, float, float, struct xnn_f32_scale_params, xnn_init_f32_scale_scalar_params)
diff --git a/src/f32-rsum2/gen/f32-rsum2-avx-u8.c b/src/f32-rsum2/gen/f32-rsum2-avx-u8.c
new file mode 100644
index 00000000000..f926cbf7411
--- /dev/null
+++ b/src/f32-rsum2/gen/f32-rsum2-avx-u8.c
@@ -0,0 +1,378 @@
+// clang-format off
+// Auto-generated file. Do not edit!
+// Template: src/f32-rsum2/simd.c.in
+// Generator: tools/xngen
+//
+// Copyright 2023 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
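
Reviewer note (not part of the patch): every xnn_f32_rsum2 kernel declared in the table above, and defined in the generated files that follow, shares one contract: `batch` is a byte count, the inputs are squared and summed, and the scaled result is accumulated into *output rather than overwriting it. A minimal scalar sketch, with the scale passed directly instead of through struct xnn_f32_scale_params to keep it self-contained (the real kernels read params->scalar.scale):

#include <stddef.h>

static void f32_rsum2_reference(size_t batch, const float* input,
                                float* output, float scale) {
  // `batch` is in bytes and must be a whole number of floats,
  // as asserted by the real kernels.
  float acc = 0.0f;
  for (size_t i = 0; i < batch / sizeof(float); i++) {
    acc += input[i] * input[i];
  }
  // Note the +=: the scaled sum of squares is added to *output.
  *output += acc * scale;
}
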
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-avx.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + if (num_elements != 0) { + xnn_simd_f32_t tail = xnn_load_tail_safe_f32(input, num_elements); + tail = xnn_mul_f32(tail, tail); + acc = xnn_add_f32(acc, tail); + } + return xnn_reduce_add_f32(acc); +} + +void xnn_f32_rsum2_ukernel__avx_u8( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + + vacc0 = xnn_add_f32(vacc0, vt0); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u16_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u16( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u24_acc3( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t 
vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + for (; batch >= 24 * sizeof(float); batch -= 24 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 24; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u24( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 24 * sizeof(float); batch -= 24 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 24; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u32_acc4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 32; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + vacc3 = xnn_add_f32(vacc3, vt3); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 
32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc2 = xnn_add_f32(vacc2, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc1 = xnn_add_f32(vacc1, vacc3); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u32_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 32; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc1 = xnn_add_f32(vacc1, vt3); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx_u32( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 32; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc0 = xnn_add_f32(vacc0, vt3); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, 
vt); + } + if (batch >= 32) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 8; + batch -= 32; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} diff --git a/src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c b/src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c new file mode 100644 index 00000000000..f2718a049b1 --- /dev/null +++ b/src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c @@ -0,0 +1,378 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-avx512f.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + if (num_elements != 0) { + xnn_simd_f32_t tail = xnn_load_tail_safe_f32(input, num_elements); + tail = xnn_mul_f32(tail, tail); + acc = xnn_add_f32(acc, tail); + } + return xnn_reduce_add_f32(acc); +} + +void xnn_f32_rsum2_ukernel__avx512f_u16( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + + vacc0 = xnn_add_f32(vacc0, vt0); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u32_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 32; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u32( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 32 * sizeof(float); batch -= 32 * 
sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 32; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u48_acc3( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + for (; batch >= 48 * sizeof(float); batch -= 48 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 48; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u48( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 48 * sizeof(float); batch -= 48 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 48; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u64_acc4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % 
sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + for (; batch >= 64 * sizeof(float); batch -= 64 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 64; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + vacc3 = xnn_add_f32(vacc3, vt3); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc2 = xnn_add_f32(vacc2, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc1 = xnn_add_f32(vacc1, vacc3); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u64_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 64 * sizeof(float); batch -= 64 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 64; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc1 = xnn_add_f32(vacc1, vt3); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__avx512f_u64( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + 
xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 64 * sizeof(float); batch -= 64 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 64; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc0 = xnn_add_f32(vacc0, vt3); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 64) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 16; + batch -= 64; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} diff --git a/src/f32-rsum2/gen/f32-rsum2-neon.c b/src/f32-rsum2/gen/f32-rsum2-neon.c new file mode 100644 index 00000000000..ac81042370d --- /dev/null +++ b/src/f32-rsum2/gen/f32-rsum2-neon.c @@ -0,0 +1,387 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
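
Reviewer note (not part of the patch): each generated rsum2 source defines its own private load_tail_reduce_add_squared_f32 helper. The NEON variant below folds the tail in with pairwise adds, while the AVX, AVX512F and scalar variants above use xnn_load_tail_safe_f32, but the contract is the same: square the final num_elements floats (fewer than one SIMD register), add them into the accumulator, and reduce the accumulator to a scalar. A scalar restatement with the vector accumulator expanded into an array of lanes (names are illustrative):

#include <stddef.h>

static float tail_reduce_add_squared_ref(const float* acc_lanes,
                                         size_t simd_width,
                                         const float* input,
                                         size_t num_elements) {
  // Horizontal reduction of the per-lane partial sums of squares.
  float sum = 0.0f;
  for (size_t i = 0; i < simd_width; i++) {
    sum += acc_lanes[i];
  }
  // Square and add the remaining tail elements (num_elements < simd_width).
  for (size_t i = 0; i < num_elements; i++) {
    sum += input[i] * input[i];
  }
  return sum;
}
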
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-neon.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + float32x2_t result = + vadd_f32(vget_low_f32(acc), vget_high_f32(acc)); + if XNN_UNLIKELY (num_elements & 2) { + float32x2_t vt = vld1_f32(input); + input += 2; + vt = vmul_f32(vt, vt); + result = vadd_f32(result, vt); + } + result = vpadd_f32(result, result); + if XNN_UNLIKELY (num_elements & 1) { + float32x2_t vt = vld1_dup_f32(input); + vt = vmul_f32(vt, vt); + result = vadd_f32(result, vt); + } + return vget_lane_f32(result, 0); +} + +void xnn_f32_rsum2_ukernel__neon_u4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + input += 4; + + vt0 = xnn_mul_f32(vt0, vt0); + + vacc0 = xnn_add_f32(vacc0, vt0); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u8_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u8( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u12_acc3( + 
size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 12; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u12( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 12; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u16_acc4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); 
+ + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + vacc3 = xnn_add_f32(vacc3, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc2 = xnn_add_f32(vacc2, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc1 = xnn_add_f32(vacc1, vacc3); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u16_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc1 = xnn_add_f32(vacc1, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__neon_u16( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc0 = xnn_add_f32(vacc0, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input 
+= 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} diff --git a/src/f32-rsum2/gen/f32-rsum2-scalar-u1.c b/src/f32-rsum2/gen/f32-rsum2-scalar-u1.c new file mode 100644 index 00000000000..d2160dd0293 --- /dev/null +++ b/src/f32-rsum2/gen/f32-rsum2-scalar-u1.c @@ -0,0 +1,378 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-scalar.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + if (num_elements != 0) { + xnn_simd_f32_t tail = xnn_load_tail_safe_f32(input, num_elements); + tail = xnn_mul_f32(tail, tail); + acc = xnn_add_f32(acc, tail); + } + return xnn_reduce_add_f32(acc); +} + +void xnn_f32_rsum2_ukernel__scalar_u1( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 1 * sizeof(float); batch -= 1 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + input += 1; + + vt0 = xnn_mul_f32(vt0, vt0); + + vacc0 = xnn_add_f32(vacc0, vt0); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__scalar_u2_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 2 * sizeof(float); batch -= 2 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 2; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__scalar_u2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + 
assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 2 * sizeof(float); batch -= 2 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 2; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__scalar_u3_acc3( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + for (; batch >= 3 * sizeof(float); batch -= 3 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 3; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__scalar_u3( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 3 * sizeof(float); batch -= 3 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 3; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void 
xnn_f32_rsum2_ukernel__scalar_u4_acc4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 4; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + vacc3 = xnn_add_f32(vacc3, vt3); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc2 = xnn_add_f32(vacc2, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc1 = xnn_add_f32(vacc1, vacc3); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__scalar_u4_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 4; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc1 = xnn_add_f32(vacc1, vt3); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__scalar_u4( + size_t batch, const float* input, float* output, + const struct 
xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 4; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc0 = xnn_add_f32(vacc0, vt3); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 4) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 1; + batch -= 4; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} diff --git a/src/f32-rsum2/gen/f32-rsum2-sse2-u4.c b/src/f32-rsum2/gen/f32-rsum2-sse2-u4.c new file mode 100644 index 00000000000..079c19fb5d5 --- /dev/null +++ b/src/f32-rsum2/gen/f32-rsum2-sse2-u4.c @@ -0,0 +1,379 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
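+// Note: like the other rsum2 ukernels in this change, the kernels in this
+// file accumulate a scaled sum of squares into `*output` instead of
+// overwriting it: with input {1.0f, 2.0f, 3.0f}, params->scalar.scale == 1.0f,
+// and *output == 0.0f on entry, the kernel leaves *output == 14.0f (1 + 4 + 9).
+// Callers that need a fresh reduction must zero *output first, as
+// xnn_compute_normalize does before its first pass.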
+ +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-sse2.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + for (; num_elements > 0; num_elements -= 1) { + __m128 vt = _mm_load_ss(input); + input += 1; + vt = _mm_mul_ps(vt, vt); + acc = _mm_add_ss(acc, vt); + } + return xnn_reduce_add_f32(acc); +} + +void xnn_f32_rsum2_ukernel__sse2_u4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + input += 4; + + vt0 = xnn_mul_f32(vt0, vt0); + + vacc0 = xnn_add_f32(vacc0, vt0); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u8_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u8( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u12_acc3( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = 
xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 12; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u12( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 12; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u16_acc4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + vacc3 = xnn_add_f32(vacc3, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + 
vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc2 = xnn_add_f32(vacc2, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc1 = xnn_add_f32(vacc1, vacc3); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u16_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc1 = xnn_add_f32(vacc1, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__sse2_u16( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc0 = xnn_add_f32(vacc0, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); 
+ } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} diff --git a/src/f32-rsum2/gen/f32-rsum2-wasmsimd-u4.c b/src/f32-rsum2/gen/f32-rsum2-wasmsimd-u4.c new file mode 100644 index 00000000000..d00236c828c --- /dev/null +++ b/src/f32-rsum2/gen/f32-rsum2-wasmsimd-u4.c @@ -0,0 +1,386 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/f32-rsum2/simd.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-wasmsimd.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + acc = wasm_f32x4_add(acc, wasm_v64x2_shuffle(acc, acc, 1, 1)); + if XNN_UNLIKELY(num_elements & 2) { + v128_t vt = wasm_v128_load64_zero(input); + input += 2; + vt = wasm_f32x4_mul(vt, vt); + acc = wasm_f32x4_add(acc, vt); + } + acc = wasm_f32x4_add(acc, wasm_v32x4_shuffle(acc, acc, 1, 1, 1, 1)); + if XNN_UNLIKELY(num_elements & 1) { + v128_t vt = wasm_v128_load32_zero(input); + vt = wasm_f32x4_mul(vt, vt); + acc = wasm_f32x4_add(acc, vt); + } + return wasm_f32x4_extract_lane(acc, 0); +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 4 * sizeof(float); batch -= 4 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + input += 4; + + vt0 = xnn_mul_f32(vt0, vt0); + + vacc0 = xnn_add_f32(vacc0, vt0); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u8_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u8( + size_t batch, const float* input, float* 
output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + input += 8; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u12_acc3( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 12; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u12( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + input += 12; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> 
XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u16_acc4( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + xnn_simd_f32_t vacc2 = xnn_zero_f32(); + xnn_simd_f32_t vacc3 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc2 = xnn_add_f32(vacc2, vt2); + vacc3 = xnn_add_f32(vacc3, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc2 = xnn_add_f32(vacc2, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc2); + vacc1 = xnn_add_f32(vacc1, vacc3); + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void xnn_f32_rsum2_ukernel__wasmsimd_u16_acc2( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + xnn_simd_f32_t vacc1 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc1 = xnn_add_f32(vacc1, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc1 = xnn_add_f32(vacc1, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc1 = xnn_add_f32(vacc1, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + vacc0 = xnn_add_f32(vacc0, vacc1); + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} + +void 
xnn_f32_rsum2_ukernel__wasmsimd_u16( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + xnn_simd_f32_t vacc0 = xnn_zero_f32(); + for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + xnn_simd_f32_t vt1 = xnn_loadu_f32(input + 1 * xnn_simd_size_f32); + xnn_simd_f32_t vt2 = xnn_loadu_f32(input + 2 * xnn_simd_size_f32); + xnn_simd_f32_t vt3 = xnn_loadu_f32(input + 3 * xnn_simd_size_f32); + input += 16; + + vt0 = xnn_mul_f32(vt0, vt0); + vt1 = xnn_mul_f32(vt1, vt1); + vt2 = xnn_mul_f32(vt2, vt2); + vt3 = xnn_mul_f32(vt3, vt3); + + vacc0 = xnn_add_f32(vacc0, vt0); + vacc0 = xnn_add_f32(vacc0, vt1); + vacc0 = xnn_add_f32(vacc0, vt2); + vacc0 = xnn_add_f32(vacc0, vt3); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + if (batch >= 16) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += 4; + batch -= 16; + vt = xnn_mul_f32(vt, vt); + vacc0 = xnn_add_f32(vacc0, vt); + } + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; +} diff --git a/src/f32-rsum2/simd.c.in b/src/f32-rsum2/simd.c.in new file mode 100644 index 00000000000..777e15442e7 --- /dev/null +++ b/src/f32-rsum2/simd.c.in @@ -0,0 +1,117 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
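+// Template for the f32 rsum2 (scaled sum-of-squares reduction) microkernels.
+// Each generated ukernel computes
+//
+//   *output += params->scalar.scale * sum(input[i] * input[i])
+//
+// over `batch` bytes of float input. The main loop consumes BATCH_TILE floats
+// per iteration across up to SIMD_TILE independent accumulators, the trailing
+// `if (batch >= ...)` blocks drain any remaining full vectors, and
+// load_tail_reduce_add_squared_f32 folds in the final partial vector.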
+ +$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(",")) +$SIMD_SIZE = BATCH_TILES[0] +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/reduce.h" +#include "src/xnnpack/simd/f32-${ARCH}.h" + +static XNN_INLINE float load_tail_reduce_add_squared_f32(xnn_simd_f32_t acc, + const float* input, + size_t num_elements) { + assert(num_elements < xnn_simd_size_f32); + $if ARCH == "neon": + float32x2_t result = + vadd_f32(vget_low_f32(acc), vget_high_f32(acc)); + if XNN_UNLIKELY (num_elements & 2) { + float32x2_t vt = vld1_f32(input); + input += 2; + vt = vmul_f32(vt, vt); + result = vadd_f32(result, vt); + } + result = vpadd_f32(result, result); + if XNN_UNLIKELY (num_elements & 1) { + float32x2_t vt = vld1_dup_f32(input); + vt = vmul_f32(vt, vt); + result = vadd_f32(result, vt); + } + return vget_lane_f32(result, 0); + $elif ARCH == "sse2": + for (; num_elements > 0; num_elements -= 1) { + __m128 vt = _mm_load_ss(input); + input += 1; + vt = _mm_mul_ps(vt, vt); + acc = _mm_add_ss(acc, vt); + } + return xnn_reduce_add_f32(acc); + $elif ARCH == "wasmsimd": + acc = wasm_f32x4_add(acc, wasm_v64x2_shuffle(acc, acc, 1, 1)); + if XNN_UNLIKELY(num_elements & 2) { + v128_t vt = wasm_v128_load64_zero(input); + input += 2; + vt = wasm_f32x4_mul(vt, vt); + acc = wasm_f32x4_add(acc, vt); + } + acc = wasm_f32x4_add(acc, wasm_v32x4_shuffle(acc, acc, 1, 1, 1, 1)); + if XNN_UNLIKELY(num_elements & 1) { + v128_t vt = wasm_v128_load32_zero(input); + vt = wasm_f32x4_mul(vt, vt); + acc = wasm_f32x4_add(acc, vt); + } + return wasm_f32x4_extract_lane(acc, 0); + $else: + if (num_elements != 0) { + xnn_simd_f32_t tail = xnn_load_tail_safe_f32(input, num_elements); + tail = xnn_mul_f32(tail, tail); + acc = xnn_add_f32(acc, tail); + } + return xnn_reduce_add_f32(acc); +} +$for BATCH_TILE in BATCH_TILES: + $assert BATCH_TILE % SIMD_SIZE == 0 + $assert BATCH_TILE >= SIMD_SIZE + $SIMD_TILE = BATCH_TILE // SIMD_SIZE + $ACCUMULATORS = SIMD_TILE + $while ACCUMULATORS: + $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS + + void xnn_f32_rsum2_ukernel__${ARCH}_u${BATCH_TILE}${ACC_SUFFIX}( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* restrict params) { + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + $for A in range(ACCUMULATORS): + xnn_simd_f32_t vacc${A} = xnn_zero_f32(); + for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) { + xnn_simd_f32_t vt0 = xnn_loadu_f32(input); + $for N in range(1, SIMD_TILE): + xnn_simd_f32_t vt${N} = xnn_loadu_f32(input + ${N} * xnn_simd_size_f32); + input += ${BATCH_TILE}; + + $for N in range(0, SIMD_TILE): + vt${N} = xnn_mul_f32(vt${N}, vt${N}); + + $for N in range(SIMD_TILE): + vacc${N % ACCUMULATORS} = xnn_add_f32(vacc${N % ACCUMULATORS}, vt${N}); + } + $for N in range(0, SIMD_TILE - 1): + if (batch >= ${SIMD_SIZE * 4}) { + xnn_simd_f32_t vt = xnn_loadu_f32(input); + input += ${SIMD_SIZE}; + batch -= ${SIMD_SIZE * 4}; + vt = xnn_mul_f32(vt, vt); + vacc${N % ACCUMULATORS} = xnn_add_f32(vacc${N % ACCUMULATORS}, vt); + } + $ACC_SLICE = (ACCUMULATORS + 1)//2 + $while ACC_SLICE > 0: + $for A in range(0, ACC_SLICE): + $if A + ACC_SLICE < ACCUMULATORS: + vacc${A} = xnn_add_f32(vacc${A}, vacc${A + ACC_SLICE}); + $ACC_SLICE //= 2 + const float vscale = params->scalar.scale; + float vresult = load_tail_reduce_add_squared_f32( + vacc0, input, batch >> 
XNN_LOG2_SIZEOF_FLOAT); + *output += vresult * vscale; + } + $ACCUMULATORS //= 2 diff --git a/src/operator-run.c b/src/operator-run.c index 44df996f0b5..00fac035a74 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "include/xnnpack.h" @@ -2040,6 +2041,48 @@ void xnn_compute_floating_point_softmax( context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params); } +void xnn_compute_normalize(struct normalize_context* restrict context, + size_t batch_index) { + const void* x = + (const void*)((uintptr_t)context->x + context->x_stride * batch_index); + void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index); + const size_t n = context->n; + + // First pass: reduce sum squared. + float sum_of_squares = 0.0f; + context->rsum2_ukernel(n, x, &sum_of_squares, &context->rsum2_params); + + // Avoid any negativity due to rounding error. + if (sum_of_squares < 0.0f) { + sum_of_squares = 0.0f; + } + + // Second pass: scale. + float scale_fp32 = 0; + switch (context->norm_type) { + case xnn_norm_l2: + scale_fp32 = 1.0f / sqrtf(context->epsilon + sum_of_squares); + break; + case xnn_norm_rms: + scale_fp32 = 1.0f / sqrtf(context->epsilon + + sum_of_squares / context->num_channels); + break; + default: + XNN_UNREACHABLE; + } + union { + float fp32; + xnn_float16 fp16; + } scale; + context->convert_scale(scale_fp32, &scale); + context->vmulc_ukernel(n, x, &scale, y, &context->minmax_params); + + // Optional third pass: scale with vector. + if (context->scale != NULL) { + context->vmul_ukernel(n, y, context->scale, y, /*params=*/NULL); + } +} + void xnn_compute_vmulcaddc(struct vmulcaddc_context* restrict context, size_t batch_start, size_t batch_size) { const size_t x_stride = context->x_stride; diff --git a/src/operator-utils.c b/src/operator-utils.c index 71b0136ba2a..27d0508672f 100644 --- a/src/operator-utils.c +++ b/src/operator-utils.c @@ -201,6 +201,20 @@ enum xnn_status xnn_destroy_operator(xnn_operator_t op) } +const char* xnn_norm_type_to_string(enum xnn_norm_type norm_type) +{ + switch (norm_type) { + case xnn_norm_l2: + return "L2"; + case xnn_norm_rms: + return "RMS"; + case xnn_norm_invalid: + return "invalid"; + } + XNN_UNREACHABLE; + return "unknown"; +} + const char* xnn_unary_operator_to_string(enum xnn_unary_operator op) { switch (op) { @@ -316,6 +330,8 @@ enum xnn_operator_type xnn_reduce_operator_to_operator_type(enum xnn_reduce_oper return xnn_operator_type_mean_nd; case xnn_reduce_sum: return xnn_operator_type_sum_nd; + case xnn_reduce_sum_squared: + return xnn_operator_type_sum_squared_nd; case xnn_reduce_max: return xnn_operator_type_reduce_max_nd; case xnn_reduce_min: diff --git a/src/operators/normalize-nc.c b/src/operators/normalize-nc.c new file mode 100644 index 00000000000..fc7fa7c54ce --- /dev/null +++ b/src/operators/normalize-nc.c @@ -0,0 +1,312 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
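+// Normalize (NC layout) operator. For each row x of `channels` elements a
+// single scale is computed,
+//
+//   L2:  s = 1 / sqrt(epsilon + sum(x[i] * x[i]))
+//   RMS: s = 1 / sqrt(epsilon + sum(x[i] * x[i]) / channels)
+//
+// and y[i] = s * x[i] is written out, optionally followed by an elementwise
+// multiply with a per-channel `scale` vector (see xnn_compute_normalize in
+// src/operator-run.c). A minimal f32 call sequence looks like the following
+// sketch; the epsilon value and the contiguous strides
+// (input_stride == output_stride == channels) are illustrative:
+//
+//   xnn_operator_t op = NULL;
+//   xnn_create_normalize_nc_f32(xnn_norm_rms, /*epsilon=*/1e-6f, /*flags=*/0, &op);
+//   xnn_reshape_normalize_nc_f32(op, channels, channels, channels, batch_size, threadpool);
+//   xnn_setup_normalize_nc_f32(op, input, /*scale=*/NULL, output);
+//   xnn_run_operator(op, threadpool);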
+ +#include +#include +#include +#include +#include + +#include "include/xnnpack.h" +#include "src/xnnpack/allocator.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/compute.h" +#include "src/xnnpack/config-types.h" +#include "src/xnnpack/config.h" +#include "src/xnnpack/log.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microfnptr.h" +#include "src/xnnpack/microparams.h" +#include "src/xnnpack/operator-type.h" +#include "src/xnnpack/operator-utils.h" +#include "src/xnnpack/operator.h" +#include "src/xnnpack/params.h" +#include + +static enum xnn_status create_normalize_nc_floating_point( + enum xnn_norm_type norm_type, float epsilon, uint32_t flags, + const struct xnn_reduce_config* rsum2_config, + const struct xnn_binary_elementwise_config* vmul_config, + enum xnn_operator_type operator_type, xnn_operator_t* normalize_op_out) { + xnn_operator_t normalize_op = NULL; + enum xnn_status status = xnn_status_uninitialized; + + if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { + xnn_log_error("failed to create %s operator: XNNPACK is not initialized", + xnn_operator_type_to_string(operator_type)); + goto error; + } + + status = xnn_status_out_of_memory; + + normalize_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator)); + if (normalize_op == NULL) { + xnn_log_error("failed to allocate %zu bytes for %s operator descriptor", + sizeof(struct xnn_operator), + xnn_operator_type_to_string(operator_type)); + goto error; + } + normalize_op->compute = + xnn_allocate_zero_memory(sizeof(struct compute_parameters)); + if (normalize_op->compute == NULL) { + xnn_log_error("failed to allocate %zu bytes for %s operator descriptor", + sizeof(struct compute_parameters), + xnn_operator_type_to_string(operator_type)); + goto error; + } + normalize_op->num_compute_invocations = 1; + + normalize_op->type = operator_type; + normalize_op->flags = flags; + normalize_op->reduce_config = rsum2_config; + normalize_op->vmul_config = vmul_config; + normalize_op->norm_type = norm_type; + normalize_op->normalize_epsilon = epsilon; + + normalize_op->state = xnn_run_state_invalid; + + *normalize_op_out = normalize_op; + return xnn_status_success; + +error: + xnn_delete_operator(normalize_op); + return status; +} + +enum xnn_status xnn_create_normalize_nc_f16(enum xnn_norm_type norm_type, + float epsilon, uint32_t flags, + xnn_operator_t* normalize_op_out) { + const struct xnn_reduce_config* rsum2_config = + xnn_init_f16_f32acc_rsum2_config(); + if (rsum2_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_normalize_nc_f16)); + return xnn_status_unsupported_hardware; + } + + const struct xnn_binary_elementwise_config* vmul_config = + xnn_init_f16_vmul_config(); + if (vmul_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_normalize_nc_f16)); + return xnn_status_unsupported_hardware; + } + + return create_normalize_nc_floating_point( + norm_type, epsilon, flags, rsum2_config, vmul_config, + xnn_operator_type_normalize_nc_f16, normalize_op_out); +} + +enum xnn_status xnn_create_normalize_nc_f32(enum xnn_norm_type norm_type, + float epsilon, uint32_t flags, + xnn_operator_t* normalize_op_out) { + const struct xnn_reduce_config* rsum2_config = xnn_init_f32_rsum2_config(); + if (rsum2_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + 
xnn_operator_type_to_string(xnn_operator_type_normalize_nc_f32)); + return xnn_status_unsupported_hardware; + } + + const struct xnn_binary_elementwise_config* vmul_config = + xnn_init_f32_vmul_config(); + if (vmul_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_normalize_nc_f32)); + return xnn_status_unsupported_hardware; + } + + return create_normalize_nc_floating_point( + norm_type, epsilon, flags, rsum2_config, vmul_config, + xnn_operator_type_normalize_nc_f32, normalize_op_out); +} + +static enum xnn_status reshape_normalize_nc_floating_point( + xnn_operator_t normalize_op, enum xnn_operator_type expected_operator_type, + size_t channels, size_t input_stride, size_t output_stride, + size_t batch_size, uint32_t log2_element_size, + xnn_rsum2_ukernel_fn rsum2_ukernel, + const struct xnn_binary_elementwise_config* vmul, + xnn_convert_scale_fn convert_scale, const void* rsum2_params, + size_t rsum2_params_size, const void* minmax_params, + size_t minmax_params_size) { + if (vmul == NULL) { + return xnn_status_unsupported_hardware; + } + if (normalize_op->type != expected_operator_type) { + xnn_log_error( + "failed to reshape operator: operator type mismatch (expected %s, got " + "%s)", + xnn_operator_type_to_string(expected_operator_type), + xnn_operator_type_to_string_v2(normalize_op)); + return xnn_status_invalid_parameter; + } + normalize_op->state = xnn_run_state_invalid; + + if (channels == 0) { + xnn_log_error( + "failed to create %s operator with %zu channels: number of channels " + "must be non-zero", + xnn_operator_type_to_string(expected_operator_type), channels); + return xnn_status_invalid_parameter; + } + + if (input_stride < channels) { + xnn_log_error( + "failed to create %s operator with input element stride of %zu: " + "stride must be at least as large as the number of channels (%zu)", + xnn_operator_type_to_string(expected_operator_type), input_stride, + channels); + return xnn_status_invalid_parameter; + } + + if (output_stride < channels) { + xnn_log_error( + "failed to create %s operator with output element stride of %zu: " + "stride must be at least as large as the number of channels (%zu)", + xnn_operator_type_to_string(expected_operator_type), output_stride, + channels); + return xnn_status_invalid_parameter; + } + + normalize_op->channels = channels; + normalize_op->input_pixel_stride = input_stride; + normalize_op->output_pixel_stride = output_stride; + + if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { + xnn_log_error("failed to reshape %s operator: XNNPACK is not initialized", + xnn_operator_type_to_string(expected_operator_type)); + return xnn_status_uninitialized; + } + + if (batch_size == 0) { + normalize_op->state = xnn_run_state_skip; + return xnn_status_success; + } + + struct normalize_context* context = &normalize_op->context.normalize; + *context = (struct normalize_context){ + .n = normalize_op->channels << log2_element_size, + .x_stride = normalize_op->input_pixel_stride << log2_element_size, + .y_stride = normalize_op->output_pixel_stride << log2_element_size, + .num_channels = normalize_op->channels, + .rsum2_ukernel = rsum2_ukernel, + .vmul_ukernel = vmul->op_ukernel, + .vmulc_ukernel = vmul->opc_ukernel, + .convert_scale = convert_scale, + .norm_type = normalize_op->norm_type, + .epsilon = normalize_op->normalize_epsilon, + }; + if (vmul->opc_ukernel != NULL) { + context->vmulc_ukernel = vmul->opc_ukernel; + }; + if 
(rsum2_params_size > 0) { + memcpy(&context->rsum2_params, rsum2_params, rsum2_params_size); + } + if (minmax_params_size > 0) { + memcpy(&context->minmax_params, minmax_params, minmax_params_size); + } + normalize_op->compute[0].type = xnn_parallelization_type_1d; + normalize_op->compute[0].task_1d = + (pthreadpool_task_1d_t)xnn_compute_normalize; + normalize_op->compute[0].range[0] = batch_size; + normalize_op->state = xnn_run_state_needs_setup; + + return xnn_status_success; +} + +static enum xnn_status setup_normalize_nc_floating_point( + xnn_operator_t normalize_op, enum xnn_operator_type expected_operator_type, + const void* input, const void* scale, void* output) { + if (normalize_op->type != expected_operator_type) { + xnn_log_error( + "failed to setup operator: operator type mismatch (expected %s, got " + "%s)", + xnn_operator_type_to_string(expected_operator_type), + xnn_operator_type_to_string_v2(normalize_op)); + return xnn_status_invalid_parameter; + } + + switch (normalize_op->state) { + case xnn_run_state_skip: + return xnn_status_success; + case xnn_run_state_invalid: + xnn_log_error( + "failed to setup %s operator: operator has not been reshaped yet", + xnn_operator_type_to_string_v2(normalize_op)); + return xnn_status_invalid_state; + case xnn_run_state_needs_setup: + // Operator has been reshaped, but not setup, continue with setup. + case xnn_run_state_ready: + // Operator has been reshaped, and we are setting up with different + // pointers. + break; + } + + normalize_op->context.normalize.x = input; + normalize_op->context.normalize.y = output; + normalize_op->context.normalize.scale = scale; + normalize_op->state = xnn_run_state_ready; + + return xnn_status_success; +} + +enum xnn_status xnn_setup_normalize_nc_f16(xnn_operator_t normalize_op, + const void* input, const void* scale, + void* output) { + return setup_normalize_nc_floating_point( + normalize_op, xnn_operator_type_normalize_nc_f16, input, scale, output); +} + +enum xnn_status xnn_setup_normalize_nc_f32(xnn_operator_t normalize_op, + const float* input, + const float* scale, float* output) { + return setup_normalize_nc_floating_point( + normalize_op, xnn_operator_type_normalize_nc_f32, input, scale, output); +} + +static void convert_scale_to_fp16(float input, void* output) { + *(xnn_float16*)output = xnn_float16_from_float(input); +} + +static void convert_scale_to_fp32(float input, void* output) { + *(float*)output = input; +} + +enum xnn_status xnn_reshape_normalize_nc_f16( + xnn_operator_t normalize_op, size_t channels, size_t input_stride, + size_t output_stride, size_t batch_size, pthreadpool_t threadpool) { + const struct xnn_f16_f32acc_scale_params rsum2_params = { + .scalar = {.scale = 1.0f}}; + + return reshape_normalize_nc_floating_point( + normalize_op, xnn_operator_type_normalize_nc_f16, channels, input_stride, + output_stride, batch_size, + /*log2_element_size=*/XNN_LOG2_SIZEOF_HALF, + normalize_op->reduce_config->ukernel, normalize_op->vmul_config, + /*convert_scale=*/convert_scale_to_fp16, + /*rsum2_params=*/&rsum2_params, + /*rsum2_params_size=*/sizeof(rsum2_params), + /*minmax_params=*/NULL, /*minmax_params_size=*/0); +} + +enum xnn_status xnn_reshape_normalize_nc_f32( + xnn_operator_t normalize_op, size_t channels, size_t input_stride, + size_t output_stride, size_t batch_size, pthreadpool_t threadpool) { + const struct xnn_f32_scale_params rsum2_params = {.scalar = {.scale = 1.0f}}; + + return reshape_normalize_nc_floating_point( + normalize_op, xnn_operator_type_normalize_nc_f32, 
channels, input_stride, + output_stride, batch_size, + /*log2_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + normalize_op->reduce_config->ukernel, normalize_op->vmul_config, + /*convert_scale=*/convert_scale_to_fp32, + /*rsum2_params=*/&rsum2_params, + /*rsum2_params_size=*/sizeof(rsum2_params), + /*minmax_params=*/NULL, /*minmax_params_size=*/0); +} diff --git a/src/operators/reduce-nd.c b/src/operators/reduce-nd.c index 3a406bc5435..8e53601995c 100644 --- a/src/operators/reduce-nd.c +++ b/src/operators/reduce-nd.c @@ -16,16 +16,16 @@ #include "src/xnnpack/compute.h" #include "src/xnnpack/config-types.h" #include "src/xnnpack/config.h" -#include "src/xnnpack/operator-utils.h" -#include "src/xnnpack/reference-config.h" #include "src/xnnpack/datatype.h" #include "src/xnnpack/log.h" #include "src/xnnpack/microkernel-type.h" #include "src/xnnpack/microparams.h" #include "src/xnnpack/normalization.h" #include "src/xnnpack/operator-type.h" +#include "src/xnnpack/operator-utils.h" #include "src/xnnpack/operator.h" #include "src/xnnpack/params.h" +#include "src/xnnpack/reference-config.h" #include static enum xnn_status create_reduce_nd( @@ -434,42 +434,62 @@ enum xnn_status xnn_create_reduce_nd( operator_type == xnn_operator_type_reduce_min_nd); // Load configs. - const struct xnn_reduce_config* config = NULL; + const struct xnn_reduce_config* reduce_config = NULL; const struct xnn_unary_elementwise_config* cvt_config = NULL; const struct xnn_xx_fill_config* fill_config = NULL; uint32_t log2_data_element_size = xnn_datatype_log2_size_bytes(datatype); uint32_t log2_accumulator_element_size; switch (datatype) { case xnn_datatype_fp16: { + switch (operator_type) { + case xnn_operator_type_sum_nd: + case xnn_operator_type_mean_nd: + reduce_config = xnn_init_f16_f32acc_rsum_config(); + break; + case xnn_operator_type_sum_squared_nd: + reduce_config = xnn_init_f16_f32acc_rsum2_config(); + break; + case xnn_operator_type_reduce_min_nd: + reduce_config = xnn_init_f16_rmin_config(); + break; + case xnn_operator_type_reduce_max_nd: + reduce_config = xnn_init_f16_rmax_config(); + break; + default: + break; + } if (is_minmax) { log2_accumulator_element_size = 1; fill_config = xnn_init_xx_fill_config(); cvt_config = cvt_unused; - - if (operator_type == xnn_operator_type_reduce_min_nd) { - config = xnn_init_f16_rmin_config(); - } else { // max - config = xnn_init_f16_rmax_config(); - } } else { log2_accumulator_element_size = 2; - config = xnn_init_f16_f32acc_rsum_config(); fill_config = fill_unused; cvt_config = xnn_init_f32_to_f16_cvt_config(); } break; } case xnn_datatype_fp32: { + switch (operator_type) { + case xnn_operator_type_sum_nd: + case xnn_operator_type_mean_nd: + reduce_config = xnn_init_f32_rsum_config(); + break; + case xnn_operator_type_sum_squared_nd: + reduce_config = xnn_init_f32_rsum2_config(); + break; + case xnn_operator_type_reduce_min_nd: + reduce_config = xnn_init_f32_rmin_config(); + break; + case xnn_operator_type_reduce_max_nd: + reduce_config = xnn_init_f32_rmax_config(); + break; + default: + break; + } if (is_minmax) { fill_config = xnn_init_xx_fill_config(); - - if (operator_type == xnn_operator_type_reduce_min_nd) { - config = xnn_init_f32_rmin_config(); - } else { // max - config = xnn_init_f32_rmax_config(); - } } else { - config = xnn_init_f32_rsum_config(); fill_config = fill_unused; } @@ -478,6 +498,20 @@ enum xnn_status xnn_create_reduce_nd( break; } case xnn_datatype_qint8: { // qs8 + switch (operator_type) { + case xnn_operator_type_sum_nd: + case xnn_operator_type_mean_nd: 
+ reduce_config = xnn_init_qs8_rsum_config(); + break; + case xnn_operator_type_reduce_min_nd: + reduce_config = xnn_init_s8_rmin_config(); + break; + case xnn_operator_type_reduce_max_nd: + reduce_config = xnn_init_s8_rmax_config(); + break; + default: + break; + } if (is_minmax) { assert(input_quantization->scale == output_quantization->scale); assert( @@ -485,15 +519,8 @@ enum xnn_status xnn_create_reduce_nd( log2_accumulator_element_size = 0; fill_config = xnn_init_xx_fill_config(); cvt_config = cvt_unused; - - if (operator_type == xnn_operator_type_reduce_min_nd) { - config = xnn_init_s8_rmin_config(); - } else { // max - config = xnn_init_s8_rmax_config(); - } } else { log2_accumulator_element_size = 2; - config = xnn_init_qs8_rsum_config(); fill_config = fill_unused; cvt_config = xnn_init_unary_reference_config( xnn_unary_convert, xnn_datatype_int32, xnn_datatype_qint8); @@ -501,6 +528,20 @@ enum xnn_status xnn_create_reduce_nd( break; } case xnn_datatype_quint8: { // qu8 + switch (operator_type) { + case xnn_operator_type_sum_nd: + case xnn_operator_type_mean_nd: + reduce_config = xnn_init_qu8_rsum_config(); + break; + case xnn_operator_type_reduce_min_nd: + reduce_config = xnn_init_u8_rmin_config(); + break; + case xnn_operator_type_reduce_max_nd: + reduce_config = xnn_init_u8_rmax_config(); + break; + default: + break; + } if (is_minmax) { assert(input_quantization->scale == output_quantization->scale); assert( @@ -508,15 +549,8 @@ enum xnn_status xnn_create_reduce_nd( log2_accumulator_element_size = 0; fill_config = xnn_init_xx_fill_config(); cvt_config = cvt_unused; - - if (operator_type == xnn_operator_type_reduce_min_nd) { - config = xnn_init_u8_rmin_config(); - } else { // max - config = xnn_init_u8_rmax_config(); - } } else { log2_accumulator_element_size = 2; - config = xnn_init_qu8_rsum_config(); // We just use an int32 -> qu8 conversion. This means we effectively // only have a 31-bit accumulator instead of 32-bit, but that seems // insignificant. @@ -533,7 +567,7 @@ enum xnn_status xnn_create_reduce_nd( }; // Check configs and restore unused pointers to NULL. - if (config == NULL || fill_config == NULL || cvt_config == NULL) { + if (reduce_config == NULL || fill_config == NULL || cvt_config == NULL) { xnn_log_error( "failed to create %s (%s) operator: unsupported hardware configuration", xnn_operator_type_to_string(operator_type), xnn_datatype_to_string(datatype)); @@ -546,9 +580,9 @@ enum xnn_status xnn_create_reduce_nd( struct xnn_reduce_params params; size_t params_size = 0; // Setup parameters - if (config->init.reduce) { - params_size = config->init.reduce(¶ms, input_quantization, - output_quantization); + if (reduce_config->init.reduce) { + params_size = reduce_config->init.reduce(¶ms, input_quantization, + output_quantization); } union xnn_unary_uparams cvt_params; size_t cvt_params_size = 0; @@ -560,15 +594,16 @@ enum xnn_status xnn_create_reduce_nd( // turn back to just return `create_reduce_nd` result. 
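Editor's note: for reference while reading the `rsum2` config selection above, the new sum-squared reduction accumulates the sum of x*x and then applies the kernel's scale parameter before folding into the existing output, which is how the tests later in this patch check it. A minimal standalone sketch in plain C++ (`reference_rsum2` is a hypothetical helper, not part of the XNNPACK microkernel API):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical scalar reference for a "reduce sum squared" kernel:
// out_new = out_old + scale * sum(x[i]^2). The real kernels accumulate in
// float (f32, or f32 accumulation for f16 inputs).
static float reference_rsum2(const std::vector<float>& x, float scale,
                             float out /* previous accumulator */) {
  float acc = 0.0f;
  for (float v : x) acc += v * v;
  return out + scale * acc;
}

int main() {
  const std::vector<float> x = {1.0f, -2.0f, 3.0f};
  // 1 + 4 + 9 = 14; with scale = 0.5 the result added to the output is 7.
  std::printf("%f\n", reference_rsum2(x, 0.5f, 0.0f));  // prints 7.000000
}
```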
enum xnn_status status = xnn_status_invalid_state; status = create_reduce_nd( - flags, log2_data_element_size, log2_accumulator_element_size, operator_type, - config, fill_config, cvt_config, ¶ms, - params_size, &cvt_params, cvt_params_size, reduce_op_out); + flags, log2_data_element_size, log2_accumulator_element_size, + operator_type, reduce_config, fill_config, cvt_config, ¶ms, + params_size, &cvt_params, cvt_params_size, reduce_op_out); if (status != xnn_status_success) { return status; } if ((datatype == xnn_datatype_fp16 || datatype == xnn_datatype_fp32) && (reduce_operator_type == xnn_reduce_sum || + reduce_operator_type == xnn_reduce_sum_squared || reduce_operator_type == xnn_reduce_mean)) { (*reduce_op_out)->ukernel.type = xnn_microkernel_type_reduce2; } diff --git a/src/runtime.c b/src/runtime.c index 92872d43686..a9eda06fc8d 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -454,6 +454,7 @@ void propagate_rank( case xnn_node_type_static_reduce_max: case xnn_node_type_static_reduce_min: case xnn_node_type_static_sum: + case xnn_node_type_static_sum_squared: if (flags & XNN_FLAG_KEEP_DIMS) { output_value->shape.num_dims = input_value->shape.num_dims; } else { @@ -470,6 +471,7 @@ void propagate_rank( case xnn_node_type_unary_elementwise: case xnn_node_type_convert: case xnn_node_type_pack_lh: + case xnn_node_type_normalize: case xnn_node_type_softmax: case xnn_node_type_static_transpose: case xnn_node_type_static_constant_pad: @@ -526,7 +528,8 @@ static enum xnn_status create_runtime_impl( const uint32_t optimization_flags = XNN_FLAG_HINT_SPARSE_INFERENCE | XNN_FLAG_HINT_FP16_INFERENCE | XNN_FLAG_FORCE_FP16_INFERENCE | XNN_FLAG_NO_OPERATOR_FUSION | - XNN_FLAG_NO_INLINED_LHS_PACKING | XNN_FLAG_SLINKY_ENABLED; + XNN_FLAG_NO_INLINED_LHS_PACKING | XNN_FLAG_SLINKY_ENABLED | + XNN_FLAG_SLOW_CONSISTENT_ARITHMETIC; status = xnn_subgraph_optimize(subgraph, flags & optimization_flags); if (status != xnn_status_success) { xnn_log_error("failed to optimize subgraph"); diff --git a/src/subgraph.c b/src/subgraph.c index 2bf0d762684..2be576acaac 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -15,6 +15,7 @@ #include "include/experimental.h" #include "include/xnnpack.h" +#include "src/subgraph/subgraph-utils.h" #include "src/xnnpack/allocation-type.h" #include "src/xnnpack/allocator.h" #include "src/xnnpack/common.h" @@ -1007,34 +1008,36 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) { return false; } switch (node->type) { - case xnn_node_type_binary_elementwise: - case xnn_node_type_unary_elementwise: + case xnn_node_type_average_pooling_2d: case xnn_node_type_batch_matrix_multiply: + case xnn_node_type_binary_elementwise: case xnn_node_type_concatenate: case xnn_node_type_convert: - case xnn_node_type_average_pooling_2d: - case xnn_node_type_copy: case xnn_node_type_convolution_2d: + case xnn_node_type_copy: case xnn_node_type_deconvolution_2d: - case xnn_node_type_depthwise_convolution_2d: case xnn_node_type_depth_to_space_2d: + case xnn_node_type_depthwise_convolution_2d: case xnn_node_type_even_split: case xnn_node_type_fully_connected: case xnn_node_type_global_average_pooling_2d: case xnn_node_type_global_sum_pooling_2d: case xnn_node_type_max_pooling_2d: + case xnn_node_type_normalize: + case xnn_node_type_rope: case xnn_node_type_softmax: case xnn_node_type_space_to_depth_2d: case xnn_node_type_static_constant_pad: case xnn_node_type_static_mean: - case xnn_node_type_static_slice: - case xnn_node_type_static_sum: - case xnn_node_type_static_reduce_min: case 
xnn_node_type_static_reduce_max: + case xnn_node_type_static_reduce_min: case xnn_node_type_static_reshape: case xnn_node_type_static_resize_bilinear_2d: + case xnn_node_type_static_slice: + case xnn_node_type_static_sum: + case xnn_node_type_static_sum_squared: case xnn_node_type_static_transpose: - case xnn_node_type_rope: + case xnn_node_type_unary_elementwise: break; case xnn_node_type_pack_lh: if (xnn_init_x16_pack_lh_config() != NULL) { @@ -2123,26 +2126,57 @@ void xnn_subgraph_fuse_unary_quantized_into_lut(xnn_subgraph_t subgraph) { } } -void xnn_subgraph_clean_up(xnn_subgraph_t subgraph) { - // Count the number of consumers for each value. - xnn_subgraph_analyze_consumers_and_producers(subgraph); +static void recursive_remove_node(xnn_subgraph_t subgraph, uint32_t node_id) { + struct xnn_node* node = &subgraph->nodes[node_id]; - // Clear unreferenced values. - for (uint32_t i = 0; i < subgraph->num_values; i++) { - struct xnn_value* value = &subgraph->values[i]; - if (value->type == xnn_value_type_invalid) { + // Decrease the number of consumers on the inputs. + for (uint32_t input_id = 0; input_id < node->num_inputs; input_id++) { + if (is_repeated_input(node, input_id)) { continue; } - - if (!xnn_value_is_external_input(value->flags) && - value->num_consumers == 0) { - if (value->producer != XNN_INVALID_NODE_ID) { - struct xnn_node* producer = &subgraph->nodes[value->producer]; + struct xnn_value* input_value = &subgraph->values[node->inputs[input_id]]; + if (!xnn_value_is_external_input(input_value->flags) && + --input_value->num_consumers == 0) { + if (input_value->producer != XNN_INVALID_NODE_ID) { + struct xnn_node* producer = &subgraph->nodes[input_value->producer]; if (producer->num_outputs == 1) { - xnn_node_clear(&subgraph->nodes[value->producer]); + recursive_remove_node(subgraph, producer->id); + } + } + xnn_value_clear(input_value); + } + } + + xnn_node_clear(node); +} + +void xnn_subgraph_clean_up(xnn_subgraph_t subgraph) { + while (true) { + // Count the number of consumers for each value. + xnn_subgraph_analyze_consumers_and_producers(subgraph); + + // Clear unreferenced values. 
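Editor's note: `recursive_remove_node` above changes the clean-up pass from one level of dead-producer removal to removing a whole chain of producers whose only consumer just disappeared. A toy sketch of the same reference-counting idea, using simplified stand-in structs rather than the real `xnn_node`/`xnn_value` types:

```cpp
#include <cstdio>
#include <vector>

// Toy graph: each value has a producer node (or -1) and a consumer count;
// each node has a list of input value ids.
struct ToyValue { int producer; int num_consumers; bool external; };
struct ToyNode { std::vector<int> inputs; bool dead; };

// Deleting a node may leave its inputs unconsumed, which in turn lets their
// producers be deleted, and so on up the chain.
static void remove_node(std::vector<ToyNode>& nodes,
                        std::vector<ToyValue>& values, int node_id) {
  ToyNode& node = nodes[node_id];
  if (node.dead) return;
  node.dead = true;
  for (int input : node.inputs) {
    ToyValue& v = values[input];
    if (!v.external && --v.num_consumers == 0 && v.producer != -1) {
      remove_node(nodes, values, v.producer);
    }
  }
}

int main() {
  // v0 (external input) -> n0 -> v1 -> n1 -> v2 (nothing consumes v2)
  std::vector<ToyValue> values = {{-1, 1, true}, {0, 1, false}, {1, 0, false}};
  std::vector<ToyNode> nodes = {{{0}, false}, {{1}, false}};
  remove_node(nodes, values, /*node_id=*/1);  // removes n1, then n0 as well
  std::printf("n0 dead=%d n1 dead=%d\n", (int)nodes[0].dead, (int)nodes[1].dead);
}
```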
+ bool changes = false; + for (uint32_t i = 0; i < subgraph->num_values; i++) { + struct xnn_value* value = &subgraph->values[i]; + if (value->type == xnn_value_type_invalid) { + continue; + } + + if (!xnn_value_is_external_input(value->flags) && + value->num_consumers == 0) { + if (value->producer != XNN_INVALID_NODE_ID) { + struct xnn_node* producer = &subgraph->nodes[value->producer]; + if (producer->num_outputs == 1) { + changes = true; + recursive_remove_node(subgraph, producer->id); + } } + xnn_value_clear(value); } - xnn_value_clear(value); + } + if (!changes) { + break; } } @@ -2162,6 +2196,7 @@ void xnn_subgraph_clean_up(xnn_subgraph_t subgraph) { uint32_t num_invalid_nodes = 0; bool changes = false; while (left + num_invalid_nodes < subgraph->num_nodes) { + num_invalid_nodes = 0; for (uint32_t i = left; i < subgraph->num_nodes; i++) { struct xnn_node* node = &subgraph->nodes[i]; @@ -2190,8 +2225,8 @@ void xnn_subgraph_clean_up(xnn_subgraph_t subgraph) { if (subgraph->nodes[left].type == xnn_node_type_invalid) { node->type = xnn_node_type_invalid; } else { - memcpy(&subgraph->nodes[left + 1], &subgraph->nodes[left], - (i - left) * sizeof(struct xnn_node)); + memmove(&subgraph->nodes[left + 1], &subgraph->nodes[left], + (i - left) * sizeof(struct xnn_node)); } subgraph->nodes[left] = tmp_node; } @@ -2282,6 +2317,491 @@ static bool convert_gemm_to_qduint8( return convert_to_qu8; } +static void swap_value_pointers(struct xnn_value** a, struct xnn_value** b) { + struct xnn_value* temp = *a; + *a = *b; + *b = temp; +} + +static float get_value_as_float(const void* data, enum xnn_datatype datatype) { + switch (datatype) { + case xnn_datatype_fp32: + return *(const float*)data; + case xnn_datatype_fp16: + return xnn_float16_to_float(*(const xnn_float16*)data); + case xnn_datatype_int32: + return *(const int32_t*)data; + default: + XNN_UNREACHABLE; + } +} + +enum xnn_status xnn_subgraph_optimize_common_subgraphs( + xnn_subgraph_t subgraph, uint32_t optimization_flags) { + // If we shouldn't change the numerics, then don't do anything. + if (optimization_flags & XNN_FLAG_SLOW_CONSISTENT_ARITHMETIC || + optimization_flags & XNN_FLAG_NO_OPERATOR_FUSION) { + return xnn_status_success; + } + + // Count the number of changes made. + size_t changes = 0; + + // Loop over the nodes in this subgraph. + for (uint32_t node_id = 0; node_id < subgraph->num_nodes; node_id++) { + struct xnn_node* node = &subgraph->nodes[node_id]; + + // Skip anything that is not a fully-connected node. + switch (node->type) { + case xnn_node_type_binary_elementwise: + // Replace `mul(x, x)` with `sqr(x)` for consistency. + if (node->binary_operator == xnn_binary_multiply && + node->num_inputs == 2 && node->inputs[0] == node->inputs[1]) { + const uint32_t input_id = node->inputs[0]; + const uint32_t output_id = node->outputs[0]; + xnn_log_info( + "Converting node mul[#%u](v%03u, v%03u) to sqr[#%u](v%03u).", + node_id, input_id, input_id, node_id, input_id); + node->type = xnn_node_type_invalid; + enum xnn_status status = xnn_define_unary(subgraph, xnn_unary_square, + /*params=*/NULL, input_id, + output_id, node->flags); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new unary-elementwise node."); + return status; + } + subgraph->nodes[node_id] = subgraph->nodes[--subgraph->num_nodes]; + subgraph->nodes[node_id].id = node_id; + subgraph->values[output_id].producer = node_id; + changes++; + } + + // Replace `div(a, sqrt(b))` with `mul(a, rsqrt(b))` for consistency. 
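Editor's note: the first rewrite in `xnn_subgraph_optimize_common_subgraphs` above turns `mul(x, x)` into `sqr(x)` so the same expression is always computed the same way (the divide-by-sqrt rewrite announced by the trailing comment is implemented just below). A simplified, self-contained version of the pattern test, using a toy node struct rather than the real subgraph types:

```cpp
#include <cstdio>

enum class Op { kMul, kSquare, kOther };

struct Node {
  Op op;
  int inputs[2];
  int num_inputs;
};

// Rewrites the node in place if it matches mul(x, x); returns whether it did.
static bool rewrite_mul_to_square(Node& node) {
  if (node.op == Op::kMul && node.num_inputs == 2 &&
      node.inputs[0] == node.inputs[1]) {
    node.op = Op::kSquare;
    node.num_inputs = 1;
    return true;
  }
  return false;
}

int main() {
  Node n = {Op::kMul, {7, 7}, 2};  // multiply value #7 by itself
  std::printf("rewritten=%d\n", rewrite_mul_to_square(n));  // prints 1
}
```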
+ while (node->binary_operator == xnn_binary_divide) { + const uint32_t output_id = node->outputs[0]; + const uint32_t value_a_id = node->inputs[0]; + struct xnn_value* value_sqrt = &subgraph->values[node->inputs[1]]; + if (value_sqrt->producer == XNN_INVALID_NODE_ID) { + break; + } + const uint32_t node_sqrt_id = value_sqrt->producer; + struct xnn_node* node_sqrt = &subgraph->nodes[node_sqrt_id]; + if (node_sqrt->type != xnn_node_type_unary_elementwise || + node_sqrt->unary_operator != xnn_unary_square_root) { + break; + } + const uint32_t value_b_id = node_sqrt->inputs[0]; + uint32_t new_value_id = XNN_INVALID_VALUE_ID; + + // If the `sqrt` node has multiple consumers, create a temporary value + // for the output of the new `rsqrt` node. + if (value_sqrt->num_consumers > 1) { + enum xnn_status status = xnn_define_tensor_value( + subgraph, value_sqrt->datatype, value_sqrt->shape.num_dims, + value_sqrt->shape.dim, value_sqrt->data, + /*external_id=*/XNN_INVALID_VALUE_ID, value_sqrt->flags, + &new_value_id); + if (status != xnn_status_success || + new_value_id == XNN_INVALID_VALUE_ID) { + xnn_log_error("Failed to create new temporary value."); + return status; + } + + // Create a new `rsqrt` node. + status = xnn_define_unary( + subgraph, xnn_unary_reciprocal_square_root, + /*params=*/NULL, + /*input_id=*/value_b_id, new_value_id, node_sqrt->flags); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new `rsqrt` node."); + return status; + } + subgraph->values[new_value_id].producer = subgraph->num_nodes - 1; + subgraph->values[new_value_id].num_consumers = 1; + subgraph->values[new_value_id].first_consumer = node_id; + } else { + new_value_id = value_sqrt->id; + node_sqrt->unary_operator = xnn_unary_reciprocal_square_root; + } + + // Create the new `mul` node. + enum xnn_status status = + xnn_define_binary(subgraph, xnn_binary_multiply, + /*params=*/NULL, value_a_id, new_value_id, + output_id, subgraph->nodes[node_id].flags); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new `rsqrt` node."); + return status; + } + subgraph->nodes[node_id] = subgraph->nodes[--subgraph->num_nodes]; + subgraph->nodes[node_id].id = node_id; + subgraph->values[output_id].producer = node_id; + + node = &subgraph->nodes[node_id]; + xnn_log_info( + "Converted div[#%u](v%03u, sqrt[#%u](v%03u)) to " + "mul[#%u](v%03u, rsqrt[#%u](v%03u)).", + node_id, value_a_id, node_sqrt_id, value_b_id, node_id, + value_a_id, subgraph->num_nodes - 1, value_b_id); + + changes++; + break; + } + + // Looks for RMS normalization subgraphs of the form: + // + // * b = reduce_sum2(a, axis=-1) + // * c = mul(b, inv_n) + // * d = add(c, eps) (optional) + // * e = rsqrt(d) + // * f = mul(a, e) + // + // and converts them to a single `normalize` node (without scaling). + while (node->binary_operator == xnn_binary_multiply) { + // Check node "f", extract values "a" and "e". + struct xnn_node* node_f = node; + const uint32_t output_id = node_f->outputs[0]; + struct xnn_value* value_a = &subgraph->values[node_f->inputs[0]]; + struct xnn_value* value_e = &subgraph->values[node_f->inputs[1]]; + if (xnn_shape_get_last_dim(&value_e->shape) != 1) { + if (xnn_shape_get_last_dim(&value_a->shape) == 1) { + swap_value_pointers(&value_a, &value_e); + } else { + break; + } + } + + // Check node "e", extract value "d". 
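Editor's note: the divide-to-rsqrt rewrite above relies on the identity a / sqrt(b) == a * (1 / sqrt(b)). The two forms can differ by a rounding error in floating point, which is why this whole pass is skipped when XNN_FLAG_SLOW_CONSISTENT_ARITHMETIC is set. A quick numeric illustration in plain C++:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float a = 3.0f, b = 2.0f;
  const float div_form = a / std::sqrt(b);           // div(a, sqrt(b))
  const float mul_form = a * (1.0f / std::sqrt(b));  // mul(a, rsqrt(b))
  // Mathematically identical; in float the results may differ by one ulp,
  // which is why the rewrite is disabled when bit-exact results are requested.
  std::printf("%.9g %.9g diff=%.3g\n", div_form, mul_form, div_form - mul_form);
}
```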
+ if (value_e->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_e = &subgraph->nodes[value_e->producer]; + if (node_e->type != xnn_node_type_unary_elementwise || + node_e->unary_operator != xnn_unary_reciprocal_square_root) { + break; + } + struct xnn_value* value_d = &subgraph->values[node_e->inputs[0]]; + + // Check optional node "d", extract values "c" and "eps". + float epsilon = 0.0f; + struct xnn_value* value_c = value_d; + while (true) { + if (value_d->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_d = &subgraph->nodes[value_d->producer]; + // while (node_d->type == xnn_node_type_static_reshape) { + // if (subgraph->values[node_d->inputs[0]].producer == + // XNN_INVALID_NODE_ID) { + // break; + // } + // node_d = + // &subgraph + // ->nodes[subgraph->values[node_d->inputs[0]].producer]; + // } + if (node_d->type != xnn_node_type_binary_elementwise || + node_d->binary_operator != xnn_binary_add) { + break; + } + struct xnn_value* value_maybe_c = + &subgraph->values[node_d->inputs[0]]; + struct xnn_value* value_eps = &subgraph->values[node_d->inputs[1]]; + if (xnn_shape_multiply_all_dims(&value_eps->shape) != 1 || + !xnn_value_is_static(value_eps->allocation_type)) { + if (xnn_shape_multiply_all_dims(&value_maybe_c->shape) == 1 && + xnn_value_is_static(value_maybe_c->allocation_type)) { + swap_value_pointers(&value_maybe_c, &value_eps); + } else { + break; + } + } + value_c = value_maybe_c; + epsilon = get_value_as_float(value_eps->data, value_eps->datatype); + break; + } + + // Check node "c", extract values "b" and verify "inv_n". + if (value_c->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_c = &subgraph->nodes[value_c->producer]; + if (node_c->type != xnn_node_type_binary_elementwise || + node_c->binary_operator != xnn_binary_multiply) { + break; + } + struct xnn_value* value_b = &subgraph->values[node_c->inputs[0]]; + struct xnn_value* value_inv_n = &subgraph->values[node_c->inputs[1]]; + if (xnn_shape_multiply_all_dims(&value_inv_n->shape) != 1 || + !xnn_value_is_static(value_inv_n->allocation_type)) { + if (xnn_shape_multiply_all_dims(&value_b->shape) == 1 && + xnn_value_is_static(value_b->allocation_type)) { + swap_value_pointers(&value_b, &value_inv_n); + } else { + break; + } + } + const float inv_n = + get_value_as_float(value_inv_n->data, value_inv_n->datatype); + float expected_inv_n = 1.0f / xnn_shape_get_last_dim(&value_a->shape); + if (value_inv_n->datatype == xnn_datatype_fp16) { + expected_inv_n = + xnn_float16_to_float(xnn_float16_from_float(expected_inv_n)); + } + if (inv_n != expected_inv_n) { + break; + } + + // Check node "b", verify value "a". + if (value_b->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_b = &subgraph->nodes[value_b->producer]; + if (node_b->type != xnn_node_type_static_sum_squared) { + break; + } + if (node_b->inputs[0] != value_a->id) { + break; + } + if (node_b->params.reduce.num_reduction_axes != 1 || + node_b->params.reduce.reduction_axes[0] != + value_a->shape.num_dims - 1) { + break; + } + + // If we made it all the way down here, then we have found an + // RMSNorm! 
+ enum xnn_status status = xnn_define_normalize( + subgraph, xnn_norm_rms, value_a->id, + /*scale_id=*/XNN_INVALID_VALUE_ID, output_id, epsilon, + /*flags=*/0); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new `normalize` node."); + return status; + } + subgraph->nodes[node_id] = subgraph->nodes[--subgraph->num_nodes]; + subgraph->nodes[node_id].id = node_id; + subgraph->values[output_id].producer = node_id; + + xnn_log_info( + "Created RMSNorm (#%u) from nodes #%u, #%u, #%u, #%u, #%u with " + "input v%03u, and output v%03u.", + node_id, value_b->producer, value_c->producer, value_d->producer, + value_e->producer, node_id, value_a->id, output_id); + changes++; + break; + } + + // Looks for L2 normalization subgraphs of the form: + // + // * b = reduce_sum2(a, axis=-1) + // * c = add(b, eps) (optional) + // * d = rsqrt(c) + // * e = mul(a, d) + // + // and converts them to a single `normalize` node (without scaling). + while (node->binary_operator == xnn_binary_multiply) { + // Check node "e", extract values "a" and "d". + struct xnn_node* node_e = node; + const uint32_t output_id = node_e->outputs[0]; + struct xnn_value* value_a = &subgraph->values[node_e->inputs[0]]; + struct xnn_value* value_d = &subgraph->values[node_e->inputs[1]]; + if (xnn_shape_get_last_dim(&value_d->shape) != 1) { + if (xnn_shape_get_last_dim(&value_a->shape) == 1) { + swap_value_pointers(&value_a, &value_d); + } else { + break; + } + } + + // Check node "d", extract value "c". + if (value_d->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_d = &subgraph->nodes[value_d->producer]; + if (node_d->type != xnn_node_type_unary_elementwise || + node_d->unary_operator != xnn_unary_reciprocal_square_root) { + break; + } + struct xnn_value* value_c = &subgraph->values[node_d->inputs[0]]; + + // Check optional node "c", extract values "b" and "eps". + struct xnn_value* value_b = value_c; + float epsilon = 0.0f; + while (true) { + if (value_c->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_c = &subgraph->nodes[value_c->producer]; + while (node_c->type == xnn_node_type_static_reshape) { + if (subgraph->values[node_c->inputs[0]].producer == + XNN_INVALID_NODE_ID) { + break; + } + node_c = + &subgraph + ->nodes[subgraph->values[node_c->inputs[0]].producer]; + } + if (node_c->type != xnn_node_type_binary_elementwise || + node_c->binary_operator != xnn_binary_add) { + break; + } + struct xnn_value* value_maybe_b = + &subgraph->values[node_c->inputs[0]]; + struct xnn_value* value_eps = &subgraph->values[node_c->inputs[1]]; + if (xnn_shape_multiply_all_dims(&value_eps->shape) != 1 || + !xnn_value_is_static(value_eps->allocation_type)) { + if (xnn_shape_multiply_all_dims(&value_maybe_b->shape) == 1 && + xnn_value_is_static(value_maybe_b->allocation_type)) { + swap_value_pointers(&value_maybe_b, &value_eps); + } else { + break; + } + } + value_b = value_maybe_b; + epsilon = get_value_as_float(value_eps->data, value_eps->datatype); + break; + } + + // Check node "b", verify value "a". + if (value_b->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_b = &subgraph->nodes[value_b->producer]; + if (node_b->type != xnn_node_type_static_sum_squared) { + break; + } + if (node_b->inputs[0] != value_a->id) { + break; + } + if (node_b->params.reduce.num_reduction_axes != 1 || + node_b->params.reduce.reduction_axes[0] != + value_a->shape.num_dims - 1) { + break; + } + + // If we made it all the way down here, then we have found an + // L2-Norm! 
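Editor's note: the two pattern matchers in this function recognize RMS and L2 normalization written out as a chain of reduce_sum2 / mul(1/n) / add(eps) / rsqrt / mul nodes. As a reference for the math the fused `normalize` node must reproduce, a hedged scalar sketch (illustrative only, not the fused operator's implementation): RMSNorm computes y_i = x_i / sqrt(mean(x^2) + eps), L2 normalization computes y_i = x_i / sqrt(sum(x^2) + eps).

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

enum class Norm { kRms, kL2 };

// Reference per-row normalization mirroring the matched node chains:
// reduce_sum2 -> (mul by 1/n for RMS only) -> add eps -> rsqrt -> mul.
static std::vector<float> normalize_row(const std::vector<float>& x, Norm norm,
                                        float eps) {
  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;
  if (norm == Norm::kRms) sum_sq /= static_cast<float>(x.size());
  const float inv = 1.0f / std::sqrt(sum_sq + eps);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); i++) y[i] = x[i] * inv;
  return y;
}

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 2.0f};
  for (float v : normalize_row(x, Norm::kRms, 0.0f)) std::printf("%f ", v);
  std::printf("\n");  // mean(x^2) = 3, so each element is divided by sqrt(3)
  for (float v : normalize_row(x, Norm::kL2, 0.0f)) std::printf("%f ", v);
  std::printf("\n");  // sum(x^2) = 9, so each element is divided by 3
}
```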
+ enum xnn_status status = xnn_define_normalize( + subgraph, xnn_norm_l2, value_a->id, + /*scale_id=*/XNN_INVALID_VALUE_ID, output_id, epsilon, + /*flags=*/0); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new `normalize` node."); + return status; + } + subgraph->nodes[node_id] = subgraph->nodes[--subgraph->num_nodes]; + subgraph->nodes[node_id].id = node_id; + subgraph->values[output_id].producer = node_id; + + xnn_log_info( + "Created L2Norm (#%u) from nodes #%u, #%u, #%u, #%u with " + "input v%03u, and output v%03u.", + node_id, value_b->producer, value_c->producer, value_d->producer, + node_id, value_a->id, output_id); + changes++; + break; + } + + // Check for a scaled normalization node. + while (node->binary_operator == xnn_binary_multiply) { + // Check node "h", extract values "g" and "scale". + const uint32_t output_id = node->outputs[0]; + struct xnn_value* value_rms = &subgraph->values[node->inputs[0]]; + struct xnn_value* value_scale = &subgraph->values[node->inputs[1]]; + if (!xnn_value_is_static(value_scale->allocation_type)) { + if (xnn_value_is_static(value_rms->allocation_type)) { + swap_value_pointers(&value_rms, &value_scale); + } else { + break; + } + } + if (xnn_shape_get_last_dim(&value_rms->shape) != + xnn_shape_multiply_all_dims(&value_scale->shape)) { + break; + } + if (value_rms->producer == XNN_INVALID_VALUE_ID) { + break; + } + struct xnn_node* node_rms = &subgraph->nodes[value_rms->producer]; + if (node_rms->type != xnn_node_type_normalize || + node_rms->num_inputs > 1) { + break; + } + + // If we got this far, then this is a scaled normalization. + enum xnn_status status = xnn_define_normalize( + subgraph, node_rms->params.normalize.norm_type, + node_rms->inputs[0], value_scale->id, output_id, + node_rms->params.normalize.epsilon, + /*flags=*/0); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new `normalize` node."); + return status; + } + subgraph->nodes[node_id] = subgraph->nodes[--subgraph->num_nodes]; + subgraph->nodes[node_id].id = node_id; + subgraph->values[output_id].producer = node_id; + + xnn_log_info( + "Created mul[#%u](normalize[#%u](v%03u), v%03u) to " + "normalize[#%u](v%03u) with scaling.", + node_id, value_rms->producer, node_rms->inputs[0], + value_scale->id, node_id, node->inputs[0]); + changes++; + break; + } + break; + + case xnn_node_type_static_sum: + // Convert `reduce_sum(sqr(a))` to `reduce_sum2(a)`. 
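Editor's note: the `static_sum` case that starts above (implementation just below) folds `reduce_sum(sqr(a))` into the new `reduce_sum2(a)` node. The two compute the same value; the fused form simply avoids materializing the squared tensor. A tiny check of the identity in plain C++:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> a = {1.5f, -2.0f, 4.0f};

  // Unfused: square into a temporary tensor, then reduce it.
  std::vector<float> squared(a.size());
  for (size_t i = 0; i < a.size(); i++) squared[i] = a[i] * a[i];
  float sum_of_squared = 0.0f;
  for (float v : squared) sum_of_squared += v;

  // Fused: accumulate x*x directly, no temporary tensor.
  float sum_squared = 0.0f;
  for (float v : a) sum_squared += v * v;

  std::printf("%f %f\n", sum_of_squared, sum_squared);  // identical results
}
```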
+ while (true) { + struct xnn_value* value_arg = &subgraph->values[node->inputs[0]]; + if (!(value_arg->datatype == xnn_datatype_fp16 || + value_arg->datatype == xnn_datatype_fp32) || + value_arg->producer == XNN_INVALID_NODE_ID) { + break; + } + struct xnn_node* node_arg = &subgraph->nodes[value_arg->producer]; + if (node_arg->type != xnn_node_type_unary_elementwise || + node_arg->unary_operator != xnn_unary_square) { + break; + } + const uint32_t output_id = node->outputs[0]; + enum xnn_status status = xnn_define_static_reduce_v2( + subgraph, xnn_reduce_sum_squared, + node->params.reduce.num_reduction_axes, + node->params.reduce.reduction_axes, node_arg->inputs[0], + output_id, node->flags); + if (status != xnn_status_success) { + xnn_log_error("Failed to create new `normalize` node."); + return status; + } + subgraph->nodes[node_id] = subgraph->nodes[--subgraph->num_nodes]; + subgraph->nodes[node_id].id = node_id; + subgraph->values[output_id].producer = node_id; + + xnn_log_info( + "Converted reduce_sum[#%u](sqr[#%u](v%03u)) to " + "reduce_sum2[#%u](v%03u).", + node_id, value_arg->producer, node->inputs[0], node_id, + node->inputs[0]); + changes++; + break; + } + break; + default: + break; + } + } + + // Clean up after ourselves. + if (changes) { + xnn_subgraph_clean_up(subgraph); + } + + return xnn_status_success; +} + enum xnn_status xnn_subgraph_optimize_packed_lhs(xnn_subgraph_t subgraph, uint32_t optimization_flags) { // Count the number of changes made. @@ -2656,6 +3176,13 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph, return xnn_status_unsupported_hardware; } + // Apply some common subgraph optimizations. + enum xnn_status status = + xnn_subgraph_optimize_common_subgraphs(subgraph, optimization_flags); + if (status != xnn_status_success) { + return status; + } + if ((optimization_flags & XNN_FLAG_FORCE_FP16_INFERENCE) && (!xnn_is_f16_compatible_config(hardware_config))) { xnn_log_error( @@ -2705,8 +3232,7 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph, optimization_flags |= XNN_FLAG_NO_INLINED_LHS_PACKING; } - enum xnn_status status = - xnn_subgraph_optimize_packed_lhs(subgraph, optimization_flags); + status = xnn_subgraph_optimize_packed_lhs(subgraph, optimization_flags); if (status != xnn_status_success) { return status; } @@ -2769,6 +3295,8 @@ enum xnn_node_type xnn_reduce_operator_to_node_type( return xnn_node_type_static_reduce_min; case xnn_reduce_sum: return xnn_node_type_static_sum; + case xnn_reduce_sum_squared: + return xnn_node_type_static_sum_squared; default: return xnn_node_type_invalid; } @@ -2785,6 +3313,8 @@ enum xnn_reduce_operator xnn_node_type_to_reduce_operator( return xnn_reduce_min; case xnn_node_type_static_sum: return xnn_reduce_sum; + case xnn_node_type_static_sum_squared: + return xnn_reduce_sum_squared; default: return xnn_reduce_invalid; } diff --git a/src/subgraph/normalize.c b/src/subgraph/normalize.c new file mode 100644 index 00000000000..bcebbb4e5b5 --- /dev/null +++ b/src/subgraph/normalize.c @@ -0,0 +1,273 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
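Editor's note: the new `src/subgraph/normalize.c` below maps the node onto the `normalize_nc_f16/f32` operators, treating the innermost dimension as channels and everything else as batch, with an optional per-channel scale as a second input. As a reference for the per-row computation, a hedged sketch; `normalize_nc_reference` is a hypothetical name, not the operator's real internals:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

enum class Norm { kRms, kL2 };

// Reference for an NC normalization: each of `batch` rows of `channels`
// elements is normalized independently; `scale`, if non-empty, holds one
// multiplier per channel (mirroring the node's optional second input).
static void normalize_nc_reference(Norm norm, float eps, size_t channels,
                                   size_t batch, const std::vector<float>& x,
                                   const std::vector<float>& scale,
                                   std::vector<float>& y) {
  for (size_t b = 0; b < batch; b++) {
    const float* row_in = &x[b * channels];
    float* row_out = &y[b * channels];
    float sum_sq = 0.0f;
    for (size_t c = 0; c < channels; c++) sum_sq += row_in[c] * row_in[c];
    if (norm == Norm::kRms) sum_sq /= static_cast<float>(channels);
    const float inv = 1.0f / std::sqrt(sum_sq + eps);
    for (size_t c = 0; c < channels; c++) {
      const float s = scale.empty() ? 1.0f : scale[c];
      row_out[c] = row_in[c] * inv * s;
    }
  }
}

int main() {
  const size_t channels = 4, batch = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> y(x.size());
  normalize_nc_reference(Norm::kRms, 1e-6f, channels, batch, x, {}, y);
  return 0;
}
```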
+ +#include +#include +#include +#include + +#include "include/xnnpack.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/log.h" +#include "src/xnnpack/node-type.h" +#include "src/xnnpack/operator-type.h" +#include "src/xnnpack/operator-utils.h" +#include "src/xnnpack/operator.h" +#include "src/xnnpack/reshape-helpers.h" +#include "src/xnnpack/subgraph-validation.h" +#include "src/xnnpack/subgraph.h" +#include + +static enum xnn_status create_normalize_operator( + const struct xnn_node* node, const struct xnn_runtime_value* values, + size_t num_values, struct xnn_operator_data* opdata, + xnn_weights_cache_t weights_cache) { + assert(node->num_inputs == 1 || node->num_inputs == 2); + assert(node->num_outputs == 1); + + const uint32_t input_id = node->inputs[0]; + assert(input_id != XNN_INVALID_VALUE_ID); + assert(input_id < num_values); + const struct xnn_runtime_value* input_value = &values[input_id]; + enum xnn_status status; + switch (input_value->datatype) { + case xnn_datatype_fp32: + status = xnn_create_normalize_nc_f32( + node->params.normalize.norm_type, node->params.normalize.epsilon, + node->flags, &opdata->operator_objects[0]); + break; + case xnn_datatype_fp16: + status = xnn_create_normalize_nc_f16( + node->params.normalize.norm_type, node->params.normalize.epsilon, + node->flags, &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + } + return status; +} + +static enum xnn_status reshape_normalize_operator( + struct xnn_operator_data* opdata, struct xnn_runtime_value* values, + size_t num_values, pthreadpool_t threadpool) { + const uint32_t input_id = opdata->inputs[0]; + assert(input_id < num_values); + + const size_t num_input_dims = values[input_id].shape.num_dims; + assert(num_input_dims > 0); + const size_t channel_dim = values[input_id].shape.dim[num_input_dims - 1]; + const size_t batch_size = + xnn_shape_multiply_non_channel_dims(&values[input_id].shape); + + if (opdata->num_inputs > 1 && opdata->inputs[1] != XNN_INVALID_VALUE_ID) { + const uint32_t scale_id = opdata->inputs[1]; + const struct xnn_runtime_value* scale_value = &values[scale_id]; + if (xnn_shape_get_last_dim(&scale_value->shape) != channel_dim || + xnn_shape_multiply_all_dims(&scale_value->shape) != channel_dim) { + xnn_log_error("Bad scale size, expected %zu, got %zu.", channel_dim, + xnn_shape_get_last_dim(&scale_value->shape)); + return xnn_status_invalid_parameter; + } + } + + const size_t old_workspace_size = opdata->workspace_size; + enum xnn_status status = xnn_status_invalid_state; + switch (opdata->operator_objects[0]->type) { + case xnn_operator_type_normalize_nc_f32: + status = xnn_reshape_normalize_nc_f32(opdata->operator_objects[0], + /*channels=*/channel_dim, + /*input_stride=*/channel_dim, + /*output_stride=*/channel_dim, + batch_size, threadpool); + break; + case xnn_operator_type_normalize_nc_f16: + status = xnn_reshape_normalize_nc_f16(opdata->operator_objects[0], + /*channels=*/channel_dim, + /*input_stride=*/channel_dim, + /*output_stride=*/channel_dim, + batch_size, threadpool); + break; + default: + XNN_UNREACHABLE; + } + if (status != xnn_status_success) { + return status; + } + return resize_unary_elementwise_output_tensor(opdata, values, num_values, + old_workspace_size, threadpool); +} + +static enum xnn_status setup_normalize_operator( + const struct xnn_operator_data* opdata, + const struct xnn_runtime_value* values, size_t num_values, + pthreadpool_t threadpool) { + const uint32_t input_id = opdata->inputs[0]; + assert(input_id != 
XNN_INVALID_VALUE_ID); + assert(input_id < num_values); + + const uint32_t output_id = opdata->outputs[0]; + assert(output_id != XNN_INVALID_VALUE_ID); + assert(output_id < num_values); + + const struct xnn_runtime_value* input_value = values + input_id; + const void* input_data = input_value->data; + assert(input_data != NULL); + + const struct xnn_runtime_value* output_value = values + output_id; + void* output_data = output_value->data; + assert(output_data != NULL); + + const void* scale_data = NULL; + if (opdata->num_inputs == 2) { + const uint32_t scale_id = opdata->inputs[1]; + assert(scale_id != XNN_INVALID_VALUE_ID); + assert(scale_id < num_values); + const struct xnn_runtime_value* scale_value = values + scale_id; + scale_data = scale_value->data; + assert(scale_data != NULL); + } + + switch (opdata->operator_objects[0]->type) { + case xnn_operator_type_normalize_nc_f32: + return xnn_setup_normalize_nc_f32(opdata->operator_objects[0], input_data, + scale_data, output_data); + case xnn_operator_type_normalize_nc_f16: + return xnn_setup_normalize_nc_f16(opdata->operator_objects[0], input_data, + scale_data, output_data); + default: + XNN_UNREACHABLE; + } +} + +enum xnn_status xnn_define_normalize(xnn_subgraph_t subgraph, + enum xnn_norm_type norm_type, + uint32_t input_id, uint32_t scale_id, + uint32_t output_id, float epsilon, + uint32_t flags) { + enum xnn_status status; + if ((status = xnn_subgraph_check_xnnpack_initialized( + xnn_node_type_normalize)) != xnn_status_success) { + return status; + } + + switch (norm_type) { + case xnn_norm_l2: + case xnn_norm_rms: + break; + default: + xnn_log_error("failed to define %s operator with input ID #%" PRIu32 + ": invalid norm type %i (%s)", + xnn_node_type_to_string(xnn_node_type_normalize), input_id, + norm_type, xnn_norm_type_to_string(norm_type)); + } + + if (epsilon < 0) { + xnn_log_error("failed to define %s operator with input ID #%" PRIu32 + ": value of epsilon should be positive (got %e)", + xnn_node_type_to_string(xnn_node_type_normalize), input_id, + epsilon); + } + + if ((status = xnn_subgraph_check_input_node_id( + xnn_node_type_normalize, input_id, subgraph->num_values)) != + xnn_status_success) { + return status; + } + + const struct xnn_value* input_value = &subgraph->values[input_id]; + status = xnn_subgraph_check_input_type_dense(xnn_node_type_normalize, + input_id, input_value); + if (status != xnn_status_success) { + return status; + } + + switch (input_value->datatype) { + case xnn_datatype_fp16: + case xnn_datatype_fp32: + break; + default: + xnn_log_error("failed to define %s operator with input ID #%" PRIu32 + ": unsupported Value datatype %s (%d)", + xnn_node_type_to_string(xnn_node_type_normalize), input_id, + xnn_datatype_to_string(input_value->datatype), + input_value->datatype); + return xnn_status_invalid_parameter; + } + + if (scale_id != XNN_INVALID_VALUE_ID) { + if ((status = xnn_subgraph_check_input_node_id( + xnn_node_type_normalize, scale_id, subgraph->num_values)) != + xnn_status_success) { + return status; + } + + const struct xnn_value* scale_value = &subgraph->values[scale_id]; + status = xnn_subgraph_check_input_type_dense(xnn_node_type_normalize, + scale_id, scale_value); + if (status != xnn_status_success) { + return status; + } + + switch (scale_value->datatype) { + case xnn_datatype_fp16: + case xnn_datatype_fp32: + break; + default: + xnn_log_error("failed to define %s operator with scale ID #%" PRIu32 + ": unsupported Value datatype %s (%d)", + 
xnn_node_type_to_string(xnn_node_type_normalize), + scale_id, xnn_datatype_to_string(scale_value->datatype), + scale_value->datatype); + return xnn_status_invalid_parameter; + } + } + + status = xnn_subgraph_check_output_node_id(xnn_node_type_normalize, output_id, + subgraph->num_values); + if (status != xnn_status_success) { + return status; + } + + const struct xnn_value* output_value = &subgraph->values[output_id]; + status = xnn_subgraph_check_output_type_dense(xnn_node_type_normalize, + output_id, output_value); + if (status != xnn_status_success) { + return status; + } + + switch (output_value->datatype) { + case xnn_datatype_fp16: + case xnn_datatype_fp32: + break; + default: + xnn_log_error("failed to define %s operator with output ID #%" PRIu32 + ": unsupported Value datatype %s (%d)", + xnn_node_type_to_string(xnn_node_type_normalize), output_id, + xnn_datatype_to_string(output_value->datatype), + output_value->datatype); + return xnn_status_invalid_parameter; + } + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_normalize; + node->params.normalize.epsilon = epsilon; + node->params.normalize.norm_type = norm_type; + node->num_inputs = scale_id == XNN_INVALID_VALUE_ID ? 1 : 2; + node->inputs[0] = input_id; + node->inputs[1] = scale_id; + node->num_outputs = 1; + node->outputs[0] = output_id; + node->flags = flags; + + node->create = create_normalize_operator; + node->reshape = reshape_normalize_operator; + node->setup = setup_normalize_operator; + + return xnn_status_success; +} diff --git a/src/subgraph/static-reduce.c b/src/subgraph/static-reduce.c index 8d2339d7ed0..1c673adcddf 100644 --- a/src/subgraph/static-reduce.c +++ b/src/subgraph/static-reduce.c @@ -292,8 +292,18 @@ enum xnn_status xnn_define_static_reduce_v2( switch (input_value->datatype) { case xnn_datatype_fp16: case xnn_datatype_fp32: + break; case xnn_datatype_qint8: case xnn_datatype_quint8: + if (reduce_operator == xnn_reduce_sum_squared) { + xnn_log_error( + "failed to define %s operator with the first input ID #%" PRIu32 + ": unsupported Value datatype %s (%d)", + xnn_node_type_to_string(node_type), input_id, + xnn_datatype_to_string(input_value->datatype), + input_value->datatype); + return xnn_status_invalid_parameter; + } break; default: xnn_log_error( diff --git a/src/tensor.c b/src/tensor.c index e3b6f6e7c0f..4d8f9b0ddad 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -15,6 +15,7 @@ #include "include/xnnpack.h" #include "src/xnnpack/allocation-type.h" #include "src/xnnpack/common.h" +#include "src/xnnpack/config-types.h" #include "src/xnnpack/datatype.h" #include "src/xnnpack/log.h" #include "src/xnnpack/math.h" @@ -656,6 +657,10 @@ size_t xnn_shape_multiply_trailing_dims( return product; } +size_t xnn_shape_get_last_dim(const struct xnn_shape* shape) { + return shape->num_dims ? 
shape->dim[shape->num_dims - 1] : 1; +} + size_t get_tensor_size(const struct xnn_gemm_config* gemm_config, enum xnn_value_type type, enum xnn_datatype datatype, const struct xnn_shape *shape, uint32_t flags) { assert(type == xnn_value_type_dense_tensor); diff --git a/src/xnnpack/buffer.h b/src/xnnpack/buffer.h index ed0164035d8..b07fef25319 100644 --- a/src/xnnpack/buffer.h +++ b/src/xnnpack/buffer.h @@ -128,6 +128,7 @@ template T get_reduce_identity(xnn_reduce_operator op) { switch (op) { case xnn_reduce_sum: + case xnn_reduce_sum_squared: case xnn_reduce_mean: return 0; case xnn_reduce_max: @@ -760,6 +761,15 @@ class Tensor { generate([=]() { return value; }); } + template + Tensor cast() const { + Tensor res(shape()); + for (size_t k = 0; k < res.size(); k++) { + res[k] = static_cast((*this)[k]); + } + return res; + } + private: static void copy_impl(size_t rank, const size_t* extents, const size_t* src_strides, const T* src, diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index bc23c91fc71..4a4dac1c204 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -1348,6 +1348,34 @@ struct floating_point_softmax_context { XNN_PRIVATE void xnn_compute_floating_point_softmax( struct floating_point_softmax_context* context, size_t batch_index); +typedef void (*xnn_convert_scale_fn)(float input, void* output); + +struct normalize_context { + size_t n; + const void* x; + size_t x_stride; + void* y; + size_t y_stride; + float epsilon; + size_t num_channels; + xnn_rsum2_ukernel_fn rsum2_ukernel; + xnn_vbinary_ukernel_fn vmulc_ukernel; + xnn_vbinary_ukernel_fn vmul_ukernel; + xnn_convert_scale_fn convert_scale; + union { + struct xnn_f16_minmax_params f16; + struct xnn_f32_minmax_params f32; + } minmax_params; + union { + struct xnn_f32_scale_params f32; + } rsum2_params; + const void* scale; + enum xnn_norm_type norm_type; +}; + +XNN_PRIVATE void xnn_compute_normalize(struct normalize_context* context, + size_t batch_index); + struct rope_context { size_t scaled_channels; size_t batch_stride; diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 406398c3c3e..6584c1da78c 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -205,6 +205,7 @@ XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_xx_copy_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum_config(); +XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum2_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_rmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_rminmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_rmin_config(); @@ -212,6 +213,7 @@ XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rminmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rmin_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rsum_config(); +XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rsum2_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_s8_rmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_s8_rminmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_s8_rmin_config(); diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index 1de37832f2c..ed19f0a6445 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -780,6 +780,21 @@ typedef void (*xnn_qu8_rdsum_ukernel_fn)( size_t rows, size_t channels, const 
uint8_t* input, size_t input_stride, const uint8_t* zero, uint32_t* output, const struct xnn_qs8_rsum_params* params); + +// RDSUM2: Discontiguous Reduce-Sum squared + +typedef void (*xnn_f16_f32acc_rdsum2_ukernel_fn)( + size_t channels, size_t k1, size_t k2, size_t k3, const xnn_float16* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const xnn_float16* zero, float* output, + const struct xnn_f16_f32acc_scale_params* params); + +typedef void (*xnn_f32_rdsum2_ukernel_fn)( + size_t channels, size_t k1, size_t k2, size_t k3, const float* input, + size_t input_stride1, size_t input_stride2, size_t input_stride3, + const float* zero, float* output, + const struct xnn_f32_scale_params* params); + // RSUM: Reduce-Sum typedef void (*xnn_f16_rsum_ukernel_fn)( @@ -802,6 +817,19 @@ typedef void (*xnn_qu8_rsum_ukernel_fn)( size_t batch, const uint8_t* input, uint32_t* output, const struct xnn_qs8_rsum_params* params); +// RSUM2: Reduce-Sum Squared + +typedef void (*xnn_rsum2_ukernel_fn)(size_t batch, const void* input, + void* output, const void* params); + +typedef void (*xnn_f16_f32acc_rsum2_ukernel_fn)( + size_t batch, const xnn_float16* input, float* output, + const struct xnn_f16_f32acc_scale_params* params); + +typedef void (*xnn_f32_rsum2_ukernel_fn)( + size_t batch, const float* input, float* output, + const struct xnn_f32_scale_params* params); + // RDMINMAX: Discontiguous Reduce-MINMAX typedef void (*xnn_reduce_discontiguous_ukernel_fn)( diff --git a/src/xnnpack/node-type-defs.inc b/src/xnnpack/node-type-defs.inc index d460bfffc6d..94d4bbc8b33 100644 --- a/src/xnnpack/node-type-defs.inc +++ b/src/xnnpack/node-type-defs.inc @@ -23,8 +23,9 @@ XNN_ENUM_ITEM(xnn_node_type_depth_to_space_2d, "Depth To Space 2D") XNN_ENUM_ITEM(xnn_node_type_depthwise_convolution_2d, "Depthwise Convolution 2D") XNN_ENUM_ITEM(xnn_node_type_even_split, "Even Split") -XNN_ENUM_ITEM(xnn_node_type_fully_connected, "Fully Connected") XNN_ENUM_ITEM(xnn_node_type_fully_connected_sparse, "Fully Connected Sparse") +XNN_ENUM_ITEM(xnn_node_type_fully_connected, "Fully Connected") +XNN_ENUM_ITEM(xnn_node_type_fuse_dims, "Fuse Dims") XNN_ENUM_ITEM(xnn_node_type_global_average_pooling_1d, "Global Average Pooling 1D") XNN_ENUM_ITEM(xnn_node_type_global_average_pooling_2d, @@ -32,15 +33,15 @@ XNN_ENUM_ITEM(xnn_node_type_global_average_pooling_2d, XNN_ENUM_ITEM(xnn_node_type_global_sum_pooling_1d, "Global Sum Pooling 1D") XNN_ENUM_ITEM(xnn_node_type_global_sum_pooling_2d, "Global Sum Pooling 2D") XNN_ENUM_ITEM(xnn_node_type_max_pooling_2d, "Max Pooling 2D") +XNN_ENUM_ITEM(xnn_node_type_normalize, "Normalize") XNN_ENUM_ITEM(xnn_node_type_pack_lh, "Pack LH") XNN_ENUM_ITEM(xnn_node_type_rope, "RoPE") XNN_ENUM_ITEM(xnn_node_type_softmax, "Softmax") XNN_ENUM_ITEM(xnn_node_type_space_to_depth_2d, "Space To Depth 2D") -XNN_ENUM_ITEM(xnn_node_type_static_constant_pad, "Static Constant Pad") -XNN_ENUM_ITEM(xnn_node_type_static_expand_dims, "Static Expand Dims") -XNN_ENUM_ITEM(xnn_node_type_fuse_dims, "Fuse Dims") XNN_ENUM_ITEM(xnn_node_type_split_dims, "Split Dims") XNN_ENUM_ITEM(xnn_node_type_static_broadcast, "Static Broadcast") +XNN_ENUM_ITEM(xnn_node_type_static_constant_pad, "Static Constant Pad") +XNN_ENUM_ITEM(xnn_node_type_static_expand_dims, "Static Expand Dims") XNN_ENUM_ITEM(xnn_node_type_static_mean, "Static Mean") XNN_ENUM_ITEM(xnn_node_type_static_reduce_max, "Static Reduce Max") XNN_ENUM_ITEM(xnn_node_type_static_reduce_min, "Static Reduce Min") @@ -48,6 +49,7 @@ 
XNN_ENUM_ITEM(xnn_node_type_static_reshape, "Static Reshape") XNN_ENUM_ITEM(xnn_node_type_static_resize_bilinear_2d, "Static Resize Bilinear 2D") XNN_ENUM_ITEM(xnn_node_type_static_slice, "Static Slice") +XNN_ENUM_ITEM(xnn_node_type_static_sum_squared, "Static Sum Squared") XNN_ENUM_ITEM(xnn_node_type_static_sum, "Static Sum") XNN_ENUM_ITEM(xnn_node_type_static_transpose, "Static Transpose") XNN_ENUM_ITEM(xnn_node_type_unary_elementwise, "Unary Elementwise") diff --git a/src/xnnpack/operator-type-defs.inc b/src/xnnpack/operator-type-defs.inc index 1891fd64a60..829c268ea59 100644 --- a/src/xnnpack/operator-type-defs.inc +++ b/src/xnnpack/operator-type-defs.inc @@ -154,6 +154,8 @@ XNN_ENUM_ITEM(xnn_operator_type_max_pooling_nhwc_f32, "Max Pooling (NHWC, F32)") XNN_ENUM_ITEM(xnn_operator_type_max_pooling_nhwc_s8, "Max Pooling (NHWC, S8)") XNN_ENUM_ITEM(xnn_operator_type_max_pooling_nhwc_u8, "Max Pooling (NHWC, U8)") XNN_ENUM_ITEM(xnn_operator_type_mean_nd, "Mean (ND)") +XNN_ENUM_ITEM(xnn_operator_type_normalize_nc_f16, "RMS Norm (NC, F16)") +XNN_ENUM_ITEM(xnn_operator_type_normalize_nc_f32, "RMS Norm (NC, F32)") XNN_ENUM_ITEM(xnn_operator_type_pack_lh_x8, "Pack LH (X8)") XNN_ENUM_ITEM(xnn_operator_type_pack_lh_x16, "Pack LH (X16)") XNN_ENUM_ITEM(xnn_operator_type_pack_lh_x32, "Pack LH (X32)") @@ -178,6 +180,7 @@ XNN_ENUM_ITEM(xnn_operator_type_space_to_depth_nhwc_x16, XNN_ENUM_ITEM(xnn_operator_type_space_to_depth_nhwc_x32, "Space To Depth (NHWC, X32)") XNN_ENUM_ITEM(xnn_operator_type_sum_nd, "Sum (ND)") +XNN_ENUM_ITEM(xnn_operator_type_sum_squared_nd, "Normalize (ND)") XNN_ENUM_ITEM(xnn_operator_type_transpose_nd_x8, "Transpose (ND, X8)") XNN_ENUM_ITEM(xnn_operator_type_transpose_nd_x16, "Transpose (ND, X16)") XNN_ENUM_ITEM(xnn_operator_type_transpose_nd_x32, "Transpose (ND, X32)") diff --git a/src/xnnpack/operator-utils.h b/src/xnnpack/operator-utils.h index 76d030dfb73..ab2ad6d85ef 100644 --- a/src/xnnpack/operator-utils.h +++ b/src/xnnpack/operator-utils.h @@ -9,7 +9,9 @@ #include #include +#include "include/xnnpack.h" #include "src/xnnpack/common.h" +#include "src/xnnpack/microfnptr.h" #include "src/xnnpack/operator.h" static inline bool use_weights_cache(struct xnn_operator* op) { @@ -59,6 +61,7 @@ XNN_INTERNAL enum xnn_status xnn_allocate_extra_params( xnn_operator_t op, size_t num_extra_params); XNN_INTERNAL enum xnn_status xnn_destroy_operator(xnn_operator_t op); +XNN_INTERNAL const char* xnn_norm_type_to_string(enum xnn_norm_type norm_type); XNN_INTERNAL const char* xnn_unary_operator_to_string( enum xnn_unary_operator op); XNN_INTERNAL const char* xnn_binary_operator_to_string( @@ -70,4 +73,4 @@ XNN_INTERNAL const char* xnn_operator_type_to_string_v2(xnn_operator_t op); } #endif -#endif // XNNPACK_SRC_XNNPACK_OPERATOR_UTILS_H_ \ No newline at end of file +#endif // XNNPACK_SRC_XNNPACK_OPERATOR_UTILS_H_ diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h index 00e5610aa92..8bc6e2d41ee 100644 --- a/src/xnnpack/operator.h +++ b/src/xnnpack/operator.h @@ -21,7 +21,6 @@ #include "src/xnnpack/microparams.h" #include "src/xnnpack/node-type.h" #include "src/xnnpack/operator-type.h" -#include "src/xnnpack/pack.h" #include #ifdef __cplusplus @@ -302,8 +301,14 @@ struct xnn_operator { const struct xnn_lut32norm_config* lut32norm_config; // For F16 and F32. 
struct { - const struct xnn_raddstoreexpminusmax_config* - raddstoreexpminusmax_config; + union { + const struct xnn_raddstoreexpminusmax_config* + raddstoreexpminusmax_config; + struct { + float normalize_epsilon; + enum xnn_norm_type norm_type; + }; + }; const struct xnn_binary_elementwise_config* vmul_config; }; }; @@ -352,6 +357,7 @@ struct xnn_operator { struct subgemm_context subgemm; struct transpose_context transpose; struct floating_point_softmax_context floating_point_softmax; + struct normalize_context normalize; struct u8_softmax_context u8_softmax; struct f16_qd8_convert_context f16_qd8_convert; struct f32_qd8_convert_context f32_qd8_convert; diff --git a/src/xnnpack/reduce.h b/src/xnnpack/reduce.h index ba1dc0befac..cde811f9448 100644 --- a/src/xnnpack/reduce.h +++ b/src/xnnpack/reduce.h @@ -21,6 +21,7 @@ extern "C" { XNN_INTERNAL void ukernel(size_t batch, const datatype_in* input, \ datatype_out* output, const params_type* params); #include "src/f16-f32acc-rsum/f16-f32acc-rsum.inc" +#include "src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc" #include "src/f16-rminmax/f16-rmax.inc" #include "src/f16-rminmax/f16-rmin.inc" #include "src/f16-rminmax/f16-rminmax.inc" @@ -29,6 +30,7 @@ extern "C" { #include "src/f32-rminmax/f32-rmin.inc" #include "src/f32-rminmax/f32-rminmax.inc" #include "src/f32-rsum/f32-rsum.inc" +#include "src/f32-rsum2/f32-rsum2.inc" #include "src/qs8-rsum/qs8-rsum.inc" #include "src/qu8-rsum/qu8-rsum.inc" #undef XNN_UKERNEL @@ -53,7 +55,9 @@ extern "C" { const datatype_in* zero, datatype_out* output, \ const params_type* params); #include "src/f16-f32acc-rdsum/f16-f32acc-rdsum.inc" +#include "src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc" #include "src/f32-rdsum/f32-rdsum.inc" +#include "src/f32-rdsum2/f32-rdsum2.inc" #undef XNN_UKERNEL #define XNN_UKERNEL(arch_flags, ukernel, row_tile, batch_tile, vector_tile, \ diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h index eb232c704a7..1a1af5652d5 100644 --- a/src/xnnpack/subgraph.h +++ b/src/xnnpack/subgraph.h @@ -355,6 +355,10 @@ struct xnn_node { uint32_t dilation_height; uint32_t dilation_width; } pooling_2d; + struct { + float epsilon; + enum xnn_norm_type norm_type; + } normalize; struct { size_t pre_paddings[XNN_MAX_TENSOR_DIMS]; size_t post_paddings[XNN_MAX_TENSOR_DIMS]; @@ -580,6 +584,9 @@ size_t xnn_shape_multiply_leading_dims(const struct xnn_shape* shape, size_t xnn_shape_multiply_trailing_dims(const struct xnn_shape* shape, size_t start_dim); +// Returns the innermost dimension. 
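Editor's note: `xnn_shape_get_last_dim`, declared just below and defined in `src/tensor.c` earlier in this patch, returns the innermost dimension (or 1 for a rank-0 shape); the normalize reshape pairs it with the product of the remaining dimensions to flatten an N-D tensor into batch x channels. A small standalone illustration using stand-in helpers, not the real shape API:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-ins for the shape helpers used by the normalize reshape: the
// innermost dimension becomes "channels", the rest collapse into "batch".
static size_t last_dim(const std::vector<size_t>& dims) {
  return dims.empty() ? 1 : dims.back();
}

static size_t non_channel_product(const std::vector<size_t>& dims) {
  size_t product = 1;
  for (size_t i = 0; i + 1 < dims.size(); i++) product *= dims[i];
  return product;
}

int main() {
  const std::vector<size_t> shape = {2, 3, 8};  // e.g. [batch, tokens, channels]
  std::printf("channels=%zu batch=%zu\n", last_dim(shape),
              non_channel_product(shape));  // channels=8 batch=6
}
```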
+size_t xnn_shape_get_last_dim(const struct xnn_shape* shape); + // Get the size in bytes to hold dynamic quant params size_t xnn_tensor_get_dynamic_quant_param_size(enum xnn_datatype datatype, const struct xnn_shape* shape, diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 87fb8acf830..54b0f9d6508 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -272,8 +272,10 @@ xnnpack_cxx_library( ) for (test, shard_count) in [ ("rminmax", 1), ("rsum", 1), + ("rsum2", 1), ("rdminmax", 1), ("rdsum", 5), + ("rdsum2", 5), ]] [xnnpack_unit_test( diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 72d555b5d13..cce34cbb3b7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -118,8 +118,10 @@ SET(MICROKERNEL_UNIT_TESTS qu8-vlrelu rdminmax rdsum + rdsum2 rminmax rsum + rsum2 spmm-minmax s8-ibilinear u8-ibilinear @@ -154,6 +156,7 @@ SHARD_TEST(maxpool-minmax-test 10) IF(XNNPACK_TARGET_PROCESSOR MATCHES "^riscv") SHARD_TEST(f32-argmaxpool-test 20) SHARD_TEST(rdsum-test 10) + SHARD_TEST(rdsum2-test 10) ENDIF() SET(MICROKERNEL_DWCONV_UNIT_TESTS diff --git a/test/operators/reduce-nd.cc b/test/operators/reduce-nd.cc index ed997de2435..fd7b1474f36 100644 --- a/test/operators/reduce-nd.cc +++ b/test/operators/reduce-nd.cc @@ -317,6 +317,10 @@ class ReduceOperatorTester { static_cast( input[input_idx]); break; + case xnn_reduce_sum_squared: { + typename Config::AccumulatorType x = input[input_idx]; + accumulator[output_idx] += x * x; + } break; case xnn_reduce_max: accumulator[output_idx] = std::max( accumulator[output_idx], @@ -492,6 +496,9 @@ struct TestParam { case xnn_reduce_sum: sstr << "sum"; break; + case xnn_reduce_sum_squared: + sstr << "sum_squared"; + break; case xnn_reduce_invalid: sstr << "invalid"; break; @@ -589,7 +596,8 @@ TEST_P(ReduceNDTest, reduce) { std::vector GenerateTests() { std::vector params; for (enum xnn_reduce_operator operation : - {xnn_reduce_sum, xnn_reduce_mean, xnn_reduce_max, xnn_reduce_min}) { + {xnn_reduce_sum, xnn_reduce_sum_squared, xnn_reduce_mean, xnn_reduce_max, + xnn_reduce_min}) { for (enum xnn_datatype datatype : {xnn_datatype_fp16, xnn_datatype_fp32, xnn_datatype_qint8, xnn_datatype_quint8}) { diff --git a/test/rdsum2.cc b/test/rdsum2.cc new file mode 100644 index 00000000000..e2074051431 --- /dev/null +++ b/test/rdsum2.cc @@ -0,0 +1,365 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include + +#include +#include "src/xnnpack/buffer.h" +#include "src/xnnpack/common.h" // IWYU pragma: keep +#include "src/xnnpack/hardware-config.h" // IWYU pragma: keep +#include "src/xnnpack/isa-checks.h" +#include "src/xnnpack/microfnptr.h" +#include "src/xnnpack/microparams-init.h" // IWYU pragma: keep +#include "src/xnnpack/reduce.h" // IWYU pragma: keep +#include "test/replicable_random_device.h" + +struct Kernel; + +struct LoopInfo { + size_t begin = 0; + size_t end = 0; + size_t step = 1; + + LoopInfo() = default; + LoopInfo(size_t once) : begin(once), end(once + 1), step(1) {} + LoopInfo(size_t begin, size_t end, size_t step = 1) + : begin(begin), end(end), step(step) {} +}; + +class Tester { + public: + Tester& channels(LoopInfo value) { + channels_ = value; + return *this; + } + LoopInfo channels() const { return channels_; } + + Tester& rows(LoopInfo value) { + rows_ = value; + return *this; + } + + LoopInfo rows() const { return rows_; } + + Tester& input_stride1(size_t value) { + input_stride1_ = value; + return *this; + } + + size_t input_stride1() const { return input_stride1_; } + + Tester& scale(float scale) { + scale_ = scale; + return *this; + } + float scale() const { return scale_; } + + // Type deduction helper. + template + using UKernelFn = void (*)(size_t, size_t, size_t, size_t, const Input*, + size_t, size_t, size_t, const Input*, Output*, + const Params*); + + template + void Test(UKernelFn ukernel, + InitParams init_params) const { + xnnpack::ReplicableRandomDevice rng; + std::uniform_int_distribution size_dist(1, 2); + const size_t k2 = size_dist(rng); + const size_t k3 = size_dist(rng); + for (size_t channels = this->channels().begin; + channels < this->channels().end; channels += this->channels().step) { + const size_t input_stride1 = + input_stride1_ == -1 ? channels : input_stride1_; + xnnpack::Buffer zero(channels, 0, xnnpack::XnnExtraBytes); + for (size_t rows = this->rows().begin; rows < this->rows().end; + rows += this->rows().step) { + const size_t input_stride2 = k2 == 1 ? 0 : input_stride1 * rows; + const size_t input_stride3 = k3 == 1 ? 0 : input_stride2 * k2; + xnnpack::Buffer input(rows * input_stride1 * k2 * k3 + channels, + xnnpack::XnnExtraBytes); + xnnpack::Buffer output(channels); + + const float max_abs_value = 10.0f; + xnnpack::DatatypeGenerator input_gen(-max_abs_value, + max_abs_value); + std::generate_n(input.data(), input.size(), + [&]() { return input_gen(rng); }); + xnnpack::DatatypeGenerator output_gen(-max_abs_value, + max_abs_value); + std::generate_n(output.data(), output.size(), + [&]() { return output_gen(rng); }); + + xnnpack::Buffer expected(channels, static_cast(0)); + for (size_t r3 = 0; r3 < k3; ++r3) { + for (size_t r2 = 0; r2 < k2; ++r2) { + for (size_t r1 = 0; r1 < rows; ++r1) { + const Input* input_row = input.data() + r1 * input_stride1 + + r2 * input_stride2 + r3 * input_stride3; + for (size_t c = 0; c < channels; ++c) { + const Output x = input_row[c]; + expected[c] += x * x; + } + } + } + } + + // Note accumulation with output happens after scale. + const float scale = init_params ? 
this->scale() : 1.0f; + for (size_t c = 0; c < channels; ++c) { + expected[c] *= scale; + expected[c] += output[c]; + } + + Params params; + if (init_params) { + init_params(¶ms, scale); + } + + ukernel(channels, rows, k2, k3, input.data(), + input_stride1 * sizeof(Input), input_stride2 * sizeof(Input), + input_stride3 * sizeof(Input), zero.data(), output.data(), + ¶ms); + + const float tolerance = channels * rows * k2 * k3 * max_abs_value * + max_abs_value * scale * 2.0f * + xnnpack::NumericLimits::epsilon(); + for (size_t c = 0; c < channels; ++c) { + ASSERT_NEAR(expected[c], output[c], tolerance); + } + } + } + } + + void Test(const Kernel& kernel) const; + + private: + LoopInfo channels_; + LoopInfo rows_; + size_t input_stride1_ = -1; + float scale_ = 1.0f; +}; + +struct Kernel { + explicit Kernel(xnn_f16_f32acc_rdsum2_ukernel_fn fn, + xnn_init_f16_f32acc_scale_params_fn init_params) { + dispatch = [=](const Tester& tester) { tester.Test(fn, init_params); }; + } + explicit Kernel(xnn_f32_rdsum2_ukernel_fn fn, + xnn_init_f32_scale_params_fn init_params) { + dispatch = [=](const Tester& tester) { tester.Test(fn, init_params); }; + } + std::function dispatch; +}; + +void Tester::Test(const Kernel& kernel) const { kernel.dispatch(*this); } + +struct KernelInfo { + const char* name; + uint64_t arch_flags; + Kernel kernel; + size_t row_tile; + size_t channel_tile; + bool vector_tile; + size_t elem_size; +}; + +KernelInfo kernels[] = { +#define XNN_UKERNEL(arch_flags, ukernel, row_tile, channel_tile, vector_tile, \ + datatype_in, datatype_out, params_type, init_params) \ + {#ukernel, arch_flags, Kernel{ukernel, init_params}, row_tile, \ + channel_tile, vector_tile, sizeof(datatype_in)}, +#include "src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc" +#include "src/f32-rdsum2/f32-rdsum2.inc" +#undef XNN_UKERNEL +}; + +class Test : public testing::TestWithParam {}; + +INSTANTIATE_TEST_SUITE_P( + rdsum2, Test, testing::ValuesIn(kernels), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST_P(Test, channels_eq_2pass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile) + .rows(param.row_tile * 2) + .Test(param.kernel); +} + +TEST_P(Test, channels_eq_2pass_fulltile_with_input_stride) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile) + .rows(param.row_tile * 2) + .input_stride1(param.channel_tile + 5) + .Test(param.kernel); +} + +TEST_P(Test, channels_eq_2pass_subtile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile) + .rows({1, param.row_tile * 2}) + .Test(param.kernel); +} + +TEST_P(Test, channels_eq_2pass_subtile_with_input_stride) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile) + .rows({1, param.row_tile * 2}) + .input_stride1(param.channel_tile + 5) + .Test(param.kernel); +} + +TEST_P(Test, channels_eq_multipass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .Test(param.kernel); +} + +TEST_P(Test, channels_eq_multipass_fulltile_with_input_stride) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile) + 
.rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .input_stride1(param.channel_tile + 5) + .Test(param.kernel); +} + +TEST_P(Test, channels_div_2pass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile * 2) + .rows(param.row_tile * 2) + .Test(param.kernel); +} + +TEST_P(Test, channels_div_2pass_subtile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels(param.channel_tile * 2) + .rows({1, param.row_tile * 2}) + .Test(param.kernel); +} + +TEST_P(Test, channels_div_multipass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels( + {param.channel_tile * 2, param.channel_tile * 8, param.channel_tile}) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .Test(param.kernel); +} + +TEST_P(Test, channels_div_multipass_fulltile_with_input_stride) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels( + {param.channel_tile * 2, param.channel_tile * 8, param.channel_tile}) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .input_stride1(param.channel_tile * 8 + 5) + .Test(param.kernel); +} + +TEST_P(Test, channels_lt_2pass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester().channels(param.row_tile * 2).rows(param.row_tile).Test(param.kernel); +} + +TEST_P(Test, channels_lt_2pass_subtile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({1, param.channel_tile}) + .rows({1, param.row_tile * 2}) + .Test(param.kernel); +} + +TEST_P(Test, channels_lt_multipass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({1, param.channel_tile}) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .Test(param.kernel); +} + +TEST_P(Test, channels_lt_multipass_fulltile_with_input_stride) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({1, param.channel_tile}) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .input_stride1(param.channel_tile + 5) + .Test(param.kernel); +} + +TEST_P(Test, channels_gt_2pass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .rows(param.row_tile * 2) + .channels({param.channel_tile + 1, param.channel_tile * 2}) + .Test(param.kernel); +} + +TEST_P(Test, channels_gt_2pass_subtile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({param.channel_tile + 1, param.channel_tile * 2}) + .rows({1, param.row_tile * 2}) + .Test(param.kernel); +} + +TEST_P(Test, channels_gt_multipass_fulltile) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({param.channel_tile + 1, param.channel_tile * 2}) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .Test(param.kernel); +} + +TEST_P(Test, channels_gt_multipass_fulltile_with_input_stride) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({param.channel_tile + 1, param.channel_tile * 2}) + .rows({param.row_tile, param.row_tile * 4, param.row_tile}) + .input_stride1(param.channel_tile * 2 + 5) + .Test(param.kernel); +} + 
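+// The next case accumulates over far more rows (512) than any kernel's row
+// tile; for the f16 input kernels this exercises the f32 accumulation path
+// ("f32acc") and checks that the running sum of squares stays within the
+// tolerance computed in Tester::Test rather than losing accuracy.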
+TEST_P(Test, overflow_accumulator) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester() + .channels({param.channel_tile + 1, param.channel_tile * 2}) + .rows(512) + .Test(param.kernel); +} diff --git a/test/rsum2.cc b/test/rsum2.cc new file mode 100644 index 00000000000..3c64bf88d0d --- /dev/null +++ b/test/rsum2.cc @@ -0,0 +1,180 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include +#include "src/xnnpack/buffer.h" +#include "src/xnnpack/common.h" // IWYU pragma: keep +#include "src/xnnpack/hardware-config.h" // IWYU pragma: keep +#include "src/xnnpack/isa-checks.h" +#include "src/xnnpack/microfnptr.h" +#include "src/xnnpack/microparams-init.h" // IWYU pragma: keep +#include "src/xnnpack/reduce.h" // IWYU pragma: keep +#include "test/replicable_random_device.h" + +struct Kernel; + +class Tester { + public: + Tester& batch_size(size_t size) { + batch_size_ = size; + return *this; + } + size_t batch_size() const { return batch_size_; } + + Tester& scale(float scale) { + scale_ = scale; + return *this; + } + float scale() const { return scale_; } + + // Type deduction helper. + template + using UKernelFn = void (*)(size_t, const Input*, Output*, const Params*); + + template + void Test(UKernelFn ukernel, + InitParams init_params) const { + xnnpack::ReplicableRandomDevice rng; + xnnpack::Buffer input(batch_size(), xnnpack::XnnExtraBytes); + + const float max_abs_value = 10.0f; + xnnpack::DatatypeGenerator input_gen(-max_abs_value, max_abs_value); + std::generate_n(input.data(), input.size(), + [&]() { return input_gen(rng); }); + + xnnpack::DatatypeGenerator output_gen(-max_abs_value, + max_abs_value); + Output output = output_gen(rng); + + Output expected = 0.0f; + for (size_t i = 0; i < batch_size(); ++i) { + const Output x = input[i]; + expected += x * x; + } + + // Note accumulation with output happens after scale. + const float scale = init_params ? 
this->scale() : 1.0f; + expected *= scale; + expected += output; + + Params params; + if (init_params) { + init_params(¶ms, scale); + } + + ukernel(batch_size() * sizeof(Input), input.data(), &output, ¶ms); + + const float tolerance = batch_size() * max_abs_value * max_abs_value * + scale * 2.0f * + xnnpack::NumericLimits::epsilon(); + ASSERT_NEAR(expected, output, tolerance); + } + + void Test(const Kernel& kernel) const; + + private: + size_t batch_size_; + float scale_ = 1.0f; +}; + +struct Kernel { + explicit Kernel(xnn_f32_rsum2_ukernel_fn fn, + xnn_init_f32_scale_params_fn init_params) { + dispatch = [=](const Tester& tester) { tester.Test(fn, init_params); }; + } + explicit Kernel(xnn_f16_f32acc_rsum2_ukernel_fn fn, + xnn_init_f16_f32acc_scale_params_fn init_params) { + dispatch = [=](const Tester& tester) { tester.Test(fn, init_params); }; + } + std::function dispatch; +}; + +void Tester::Test(const Kernel& kernel) const { kernel.dispatch(*this); } + +struct KernelInfo { + const char* name; + uint64_t arch_flags; + Kernel kernel; + size_t batch_tile; + bool vector_tile; + size_t elem_size; +}; + +KernelInfo kernels[] = { +#define XNN_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, datatype_in, \ + datatype_out, params_type, init_params) \ + {#ukernel, arch_flags, Kernel{ukernel, init_params}, \ + batch_tile, vector_tile, sizeof(datatype_in)}, +#include "src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc" +#include "src/f32-rsum2/f32-rsum2.inc" +#undef XNN_UKERNEL +}; + +class Test : public testing::TestWithParam {}; + +INSTANTIATE_TEST_SUITE_P( + rsum2, Test, testing::ValuesIn(kernels), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST_P(Test, batch_eq) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + const size_t batch_tile = param.batch_tile * get_batch_scale(param.elem_size); + Tester() + .batch_size(batch_tile * get_batch_scale(param.elem_size)) + .Test(param.kernel); +} + +TEST_P(Test, batch_div) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + const size_t batch_tile = param.batch_tile * get_batch_scale(param.elem_size); + for (size_t batch_size = batch_tile; + batch_size < batch_tile * 5 && !HasFatalFailure(); + batch_size += batch_tile) { + Tester().batch_size(batch_tile).Test(param.kernel); + } +} + +TEST_P(Test, batch_lt) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + const size_t batch_tile = param.batch_tile * get_batch_scale(param.elem_size); + for (size_t batch_size = 1; batch_size < batch_tile && !HasFatalFailure(); + batch_size++) { + Tester().batch_size(batch_size).Test(param.kernel); + } +} + +TEST_P(Test, batch_gt) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + const size_t batch_tile = param.batch_tile * get_batch_scale(param.elem_size); + for (size_t batch_size = batch_tile + 1; + batch_size < batch_tile * 2 && !HasFatalFailure(); batch_size++) { + Tester().batch_size(batch_size).Test(param.kernel); + } +} + +TEST_P(Test, scale) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + for (float scale = 0.3f; scale < 5.0f && !HasFatalFailure(); scale *= 3.0f) { + Tester().batch_size(2).scale(scale).Test(param.kernel); + } +} + +TEST_P(Test, overflow_accumulator) { + const KernelInfo& param = GetParam(); + TEST_REQUIRES_ARCH_FLAGS(param.arch_flags); + Tester().batch_size(128).Test(param.kernel); +} diff --git a/test/subgraph/BUILD 
b/test/subgraph/BUILD index 672859b3488..cb69d7111fb 100644 --- a/test/subgraph/BUILD +++ b/test/subgraph/BUILD @@ -117,11 +117,12 @@ xnnpack_unit_test( ], deps = SUBGRAPH_TEST_DEPS, ) for operator in [ - "copy", "broadcast", + "copy", + "depth_to_space_2d", + "normalize", "softmax", "space_to_depth_2d", - "depth_to_space_2d", "static_constant_pad", "static_expand_dims", "static_reshape", @@ -333,11 +334,11 @@ xnnpack_unit_test( "//:allocation_type", "//:allocator", "//:buffer", + "//:common", "//:math", "//:node_type", "//:operator_h", "//:params", - "//:subgraph", "//:subgraph_h", "//:xnnpack_h", "//test:replicable_random_device", diff --git a/test/subgraph/CMakeLists.txt b/test/subgraph/CMakeLists.txt index 3676f63e958..c612776f731 100644 --- a/test/subgraph/CMakeLists.txt +++ b/test/subgraph/CMakeLists.txt @@ -51,6 +51,7 @@ IF(XNNPACK_BUILD_LIBRARY) even-split fully-connected max-pooling-2d + normalize softmax space-to-depth-2d split-fuse diff --git a/test/subgraph/normalize.cc b/test/subgraph/normalize.cc new file mode 100644 index 00000000000..db96b918a02 --- /dev/null +++ b/test/subgraph/normalize.cc @@ -0,0 +1,343 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "include/xnnpack.h" +#include "src/xnnpack/buffer.h" +#include "src/xnnpack/common.h" +#include "src/xnnpack/datatype.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/node-type.h" +#include "src/xnnpack/operator-utils.h" +#include "test/replicable_random_device.h" +#include "test/subgraph/subgraph-tester.h" + +static const float kMaxAbsInput = 10.0f; +static const float kRMSNormEpsilon = 1.0e-6f; + +namespace xnnpack { + +template +Tensor normalize(enum xnn_norm_type norm_type, Tensor x, + Tensor scale = {}) { + Tensor y(x.extents()); + std::vector batch_dims = x.extents(); + size_t channels = x.extents().back(); + batch_dims.pop_back(); + for (std::vector i : EnumerateIndices(batch_dims)) { + i.push_back(0); + const T* x_i = &x(i); + T* y_i = &y(i); + double sum_of_squares = 0.0; + for (size_t c = 0; c < channels; c++) { + const double x_i_c = x_i[c]; + sum_of_squares += x_i_c * x_i_c; + } + sum_of_squares = std::max(sum_of_squares, 0.0); + double rms_scale; + switch (norm_type) { + case xnn_norm_l2: + rms_scale = 1.0 / std::sqrt(kRMSNormEpsilon + sum_of_squares); + break; + case xnn_norm_rms: + rms_scale = + 1.0 / std::sqrt(kRMSNormEpsilon + sum_of_squares / channels); + break; + default: + XNN_UNREACHABLE; + } + for (size_t c = 0; c < channels; c++) { + y_i[c] = x_i[c] * rms_scale * + (scale.empty() ? 1.0f : static_cast(scale[c])); + } + } + return y; +} + +template +void TestImpl(size_t rank, bool use_scale, enum xnn_norm_type norm_type) { + ReplicableRandomDevice rng; + + ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); + + // Define subgraph + SubgraphTester subgraph(use_scale ? 3 : 2); + const uint32_t input_id = 0; + const uint32_t output_id = 1; + const uint32_t scale_id = use_scale ? 
2 : XNN_INVALID_VALUE_ID; + subgraph.AddInputTensor(rank, xnn_datatype_of(), input_id) + .AddOutputTensor(rank, xnn_datatype_of(), output_id); + if (use_scale) { + subgraph.AddInputTensor(1, xnn_datatype_of(), scale_id); + } + subgraph.AddRMSNorm(input_id, scale_id, output_id, norm_type, + kRMSNormEpsilon); + xnn_status status = subgraph.CreateRuntime(); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + return; + } + + for (auto _ : FuzzTest(std::chrono::milliseconds(500))) { + std::vector shape = random_shape(rng, rank); + + // Generate the input. + Tensor input(shape, xnnpack::XnnExtraBytes); + DatatypeGenerator generator(-kMaxAbsInput, kMaxAbsInput); + input.generate([&]() { return generator(rng); }); + + // Generate and populate the scale Tensor, if requested. + Tensor scale; + if (use_scale) { + scale = Tensor({shape.back()}, xnnpack::XnnExtraBytes); + DatatypeGenerator scale_generator(-1.0f, 1.0f); + scale.generate([&]() { return scale_generator(rng); }); + subgraph.ReshapeExternalTensor(scale.shape(), scale.base(), scale_id); + } + + Tensor expected = normalize(norm_type, input, scale); + + // Check reshaped shape is correct + subgraph.ReshapeExternalTensor(shape, input.base(), input_id) + .ReshapeRuntime(); + ASSERT_EQ(subgraph.GetExternalTensorShape(output_id), expected.extents()); + + // Run subgraph + // RMSNorm reads from the output assuming XNN_EXTRA_BYTES exist. + Tensor output(expected.extents(), xnnpack::XnnExtraBytes); + subgraph.SetupExternalTensor(output.base(), output_id) + .SetupRuntime() + .InvokeRuntime(); + + // Verify results. + const float tolerance = NumericLimits::epsilon() * kMaxAbsInput * + kMaxAbsInput * shape.back() * 2.0; + ASSERT_THAT(output.template cast(), + testing::Pointwise(testing::NanSensitiveFloatNear(tolerance), + expected.template cast())); + } +} + +template +void TestSubgraphRewrite(size_t rank, bool use_scale, + enum xnn_norm_type norm_type) { + ReplicableRandomDevice rng; + + ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); + + for (auto _ : FuzzTest(std::chrono::milliseconds(500))) { + // Define subgraph + SubgraphTester subgraph(2); + const uint32_t input_id = 0; + const uint32_t output_id = 1; + uint32_t scale_id = XNN_INVALID_VALUE_ID; + std::vector shape = random_shape(rng, rank); + + // Generate the input. + Tensor input(shape, xnnpack::XnnExtraBytes); + DatatypeGenerator generator(-kMaxAbsInput, kMaxAbsInput); + input.generate([&]() { return generator(rng); }); + + subgraph.AddInputTensor(shape, xnn_datatype_of(), input_id); + subgraph.AddOutputTensor(shape, xnn_datatype_of(), output_id); + + // Generate and populate the scale Tensor, if requested. + Tensor scale; + if (use_scale) { + std::vector scale_shape = shape; + std::fill(scale_shape.begin(), scale_shape.end() - 1, 1); + scale = Tensor(scale_shape, xnnpack::XnnExtraBytes); + DatatypeGenerator scale_generator(-1.0f, 1.0f); + scale.generate([&]() { return scale_generator(rng); }); + subgraph.AddInternalStaticTensor(scale_shape, xnn_datatype_of(), + &scale_id, scale.data()); + } + + // Generate the RMS/L2-Norm nodes, randomly swapping inputs where + // associative. + + // b = mul(a, a). + uint32_t squared_id = XNN_INVALID_VALUE_ID; + subgraph.AddInternalDynamicTensor(shape, xnn_datatype_of(), &squared_id, + /*flags=*/0); + subgraph.AddMultiply(input_id, input_id, squared_id); + + // c = reduce_sum(b, axis=-1). 
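+    // XNN_FLAG_KEEP_DIMS below keeps the reduced axis as a trailing size-1
+    // dimension (reduced_shape.back() = 1), so the reduction result can
+    // broadcast against the input in the divide that produces the output.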
+ uint32_t sum_of_squares_id = XNN_INVALID_VALUE_ID; + std::vector reduced_shape = shape; + reduced_shape.back() = 1; + subgraph.AddInternalDynamicTensor(reduced_shape, xnn_datatype_of(), + &sum_of_squares_id, + /*flags=*/0); + subgraph.AddReduce(xnn_reduce_sum, {static_cast(rank) - 1}, + squared_id, sum_of_squares_id, XNN_FLAG_KEEP_DIMS); + + uint32_t scaled_sum_id = sum_of_squares_id; + const T inv_n = 1.0 / shape.back(); + switch (norm_type) { + case xnn_norm_rms: { + // d = mul(c, inv_n). + uint32_t inv_n_id = XNN_INVALID_VALUE_ID; + subgraph.AddInternalDynamicTensor(reduced_shape, xnn_datatype_of(), + &scaled_sum_id, + /*flags=*/0); + subgraph.AddInternalStaticTensor(/*shape=*/{1}, xnn_datatype_of(), + &inv_n_id, &inv_n); + if (rng() % 2) { + subgraph.AddMultiply(sum_of_squares_id, inv_n_id, scaled_sum_id); + } else { + subgraph.AddMultiply(inv_n_id, sum_of_squares_id, scaled_sum_id); + } + } break; + case xnn_norm_l2: + break; + default: + XNN_UNREACHABLE; + } + + // Optionally e = add(d, eps). + uint32_t shifted_scaled_sum_id = scaled_sum_id; + T epsilon = 0.0; + if (rng() % 2) { + uint32_t epsilon_id = XNN_INVALID_VALUE_ID; + epsilon = kRMSNormEpsilon; + subgraph.AddInternalDynamicTensor(reduced_shape, xnn_datatype_of(), + &shifted_scaled_sum_id, + /*flags=*/0); + subgraph.AddInternalStaticTensor(/*shape=*/{1}, xnn_datatype_of(), + &epsilon_id, &epsilon); + if (rng() % 2) { + subgraph.AddAddition(scaled_sum_id, epsilon_id, shifted_scaled_sum_id); + } else { + subgraph.AddAddition(epsilon_id, scaled_sum_id, shifted_scaled_sum_id); + } + } + + // f = sqrt(shifted_rms_id). + uint32_t sqrt_shifted_scaled_sum_id = XNN_INVALID_VALUE_ID; + subgraph.AddInternalDynamicTensor(reduced_shape, xnn_datatype_of(), + &sqrt_shifted_scaled_sum_id, + /*flags=*/0); + subgraph.AddUnary(xnn_unary_square_root, /*params=*/nullptr, + shifted_scaled_sum_id, sqrt_shifted_scaled_sum_id); + + // g = div(a, f). + if (use_scale) { + uint32_t normalized_id = XNN_INVALID_VALUE_ID; + subgraph.AddInternalDynamicTensor(shape, xnn_datatype_of(), + &normalized_id, + /*flags=*/0); + subgraph.AddDivide(input_id, sqrt_shifted_scaled_sum_id, normalized_id); + if (rng() % 2) { + subgraph.AddMultiply(normalized_id, scale_id, output_id); + } else { + subgraph.AddMultiply(scale_id, normalized_id, output_id); + } + } else { + subgraph.AddBinary(xnn_binary_divide, /*params=*/nullptr, input_id, + sqrt_shifted_scaled_sum_id, output_id); + } + + // Set up the input/output tensors. + Tensor expected(shape, xnnpack::XnnExtraBytes); + subgraph.SetupExternalTensor(input.base(), input_id); + subgraph.SetupExternalTensor(expected.base(), output_id); + + // Evaluate once with `XNN_FLAG_SLOW_CONSISTENT_ARITHMETIC` enabled to + // prevent the subgraph replacement. + xnn_status status = subgraph.CreateRuntime( + /*threadpool=*/nullptr, XNN_FLAG_SLOW_CONSISTENT_ARITHMETIC); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + return; + } + ASSERT_GT(subgraph.NumNodes(), 1); + + // Run the subgraph. + subgraph.ReshapeRuntime(); + subgraph.SetupRuntime(); + subgraph.InvokeRuntime(); + + // Create the runtime and evaluate again and check that the subgraph was + // replaced. + subgraph.Optimize(); + ASSERT_EQ(subgraph.NumNodes(), 1); + ASSERT_EQ(subgraph.Node(0)->type, xnn_node_type_normalize); + + // Run the subgraph. + Tensor output(shape, xnnpack::XnnExtraBytes); + subgraph.SetupExternalTensor(output.base(), output_id); + + // Run the subgraph. 
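+    // Re-run through the rewritten graph (now a single normalize node) and
+    // compare against `expected`, which holds the result of the unfused run.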
+ subgraph.ReshapeRuntime(); + subgraph.SetupRuntime(); + subgraph.InvokeRuntime(); + + // Verify results. + const float tolerance = NumericLimits::epsilon() * kMaxAbsInput * + kMaxAbsInput * shape.back(); + ASSERT_THAT(output.template cast(), + testing::Pointwise(testing::NanSensitiveFloatNear(tolerance), + expected.template cast())); + } +} + +template +class NormalizeTest : public ::testing::TestWithParam< + std::tuple> {}; +std::string NormalizeTestName( + const testing::TestParamInfo>& + info) { + auto& params = info.param; + char buff[100]; + sprintf(buff, "%s_%zu%s", xnn_norm_type_to_string(std::get<2>(params)), + std::get<0>(params), std::get<1>(params) ? "_scaling" : ""); + return std::string(buff); +} +using NormalizeTestF16 = NormalizeTest; +using NormalizeTestF32 = NormalizeTest; + +TEST_P(NormalizeTestF16, test) { + TestImpl(/*rank=*/std::get<0>(GetParam()), + /*use_scale=*/std::get<1>(GetParam()), + /*norm_type=*/std::get<2>(GetParam())); +} +TEST_P(NormalizeTestF32, test) { + TestImpl(/*rank=*/std::get<0>(GetParam()), + /*use_scale=*/std::get<1>(GetParam()), + /*norm_type=*/std::get<2>(GetParam())); +} +TEST_P(NormalizeTestF16, subgraph_rewrite) { + TestSubgraphRewrite(/*rank=*/std::get<0>(GetParam()), + /*use_scale=*/std::get<1>(GetParam()), + /*norm_type=*/std::get<2>(GetParam())); +} +TEST_P(NormalizeTestF32, subgraph_rewrite) { + TestSubgraphRewrite(/*rank=*/std::get<0>(GetParam()), + /*use_scale=*/std::get<1>(GetParam()), + /*norm_type=*/std::get<2>(GetParam())); +} + +auto test_params = testing::Combine( + testing::Range(1, XNN_MAX_TENSOR_DIMS), testing::Bool(), + testing::Values(xnn_norm_l2, xnn_norm_rms)); +INSTANTIATE_TEST_SUITE_P(RMSNorm, NormalizeTestF16, test_params, + NormalizeTestName); +INSTANTIATE_TEST_SUITE_P(RMSNorm, NormalizeTestF32, test_params, + NormalizeTestName); + +} // namespace xnnpack diff --git a/test/subgraph/softmax.cc b/test/subgraph/softmax.cc index a7b171d973b..2921da42909 100644 --- a/test/subgraph/softmax.cc +++ b/test/subgraph/softmax.cc @@ -21,7 +21,7 @@ namespace xnnpack { template -Tensor softmax(Tensor x) { +Tensor rms_norm(Tensor x) { Tensor y(x.extents()); std::vector batch_dims = x.extents(); size_t channels = x.extents().back(); @@ -68,7 +68,7 @@ void TestImpl(size_t rank) { DatatypeGenerator generator(-20.0f, 20.0f); input.generate([&]() { return generator(rng); }); - Tensor expected = softmax(input); + Tensor expected = rms_norm(input); // Check reshaped shape is correct subgraph.ReshapeExternalTensor(shape, input.base(), 0).ReshapeRuntime(); diff --git a/test/subgraph/static-reduce.cc b/test/subgraph/static-reduce.cc index 9a7b2b6d363..8ace7b4dca3 100644 --- a/test/subgraph/static-reduce.cc +++ b/test/subgraph/static-reduce.cc @@ -42,6 +42,9 @@ struct Param { case xnn_reduce_sum: sstr << "sum"; break; + case xnn_reduce_sum_squared: + sstr << "sum_squared"; + break; case xnn_reduce_max: sstr << "max"; break; @@ -103,6 +106,8 @@ std::function get_reference_op(xnn_reduce_operator op) { case xnn_reduce_sum: case xnn_reduce_mean: return [](float& output, float input) { output += input; }; + case xnn_reduce_sum_squared: + return [](float& output, float input) { output += input * input; }; case xnn_reduce_min: return [](float& output, float input) { output = std::min(output, input); }; @@ -256,4 +261,12 @@ INSTANTIATE_TEST_SUITE_P(Reduce, ReduceF16, params, INSTANTIATE_TEST_SUITE_P(Reduce, ReduceF32, params, [](auto p) { return p.param.Name(); }); +auto params2 = testing::ConvertGenerator(Combine( + Values(xnn_reduce_sum_squared), + 
Bool(), Bool(), Range(0, XNN_MAX_TENSOR_DIMS))); +INSTANTIATE_TEST_SUITE_P(Reduce2, ReduceF16, params2, + [](auto p) { return p.param.Name(); }); +INSTANTIATE_TEST_SUITE_P(Reduce2, ReduceF32, params2, + [](auto p) { return p.param.Name(); }); + } // namespace xnnpack diff --git a/test/subgraph/subgraph-fp16.cc b/test/subgraph/subgraph-fp16.cc index d75a2af7d0a..307a9730d53 100644 --- a/test/subgraph/subgraph-fp16.cc +++ b/test/subgraph/subgraph-fp16.cc @@ -18,6 +18,7 @@ #include "include/xnnpack.h" #include "src/xnnpack/allocation-type.h" #include "src/xnnpack/buffer.h" +#include "src/xnnpack/common.h" #include "src/xnnpack/math.h" #include "src/xnnpack/node-type.h" #include "src/xnnpack/operator.h" @@ -1260,15 +1261,27 @@ TEST(SUBGRAPH_FP16_DUPLICATE_INPUTS, converted_only_once) { // external // output[1] - // We should have 3 nodes, the original Mul node, plus one convert node for - // each of the external input and output. + // We should have 3 nodes, the original Mul node (which is converted to + // `sqr`), plus one convert node for each of the external input and output. ASSERT_EQ(tester.NumNodes(), 3); ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert); - ASSERT_EQ(tester.Node(1)->type, xnn_node_type_binary_elementwise); + ASSERT_THAT(tester.Node(1)->type, + testing::AnyOf(xnn_node_type_unary_elementwise, + xnn_node_type_binary_elementwise)); ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert); - // Check that the inputs to the Mul node are the same value. - ASSERT_EQ(tester.Node(1)->inputs[0], tester.Node(1)->inputs[1]); + switch (tester.Node(1)->type) { + case xnn_node_type_binary_elementwise: + // Check that the inputs to the Mul node are the same value. + ASSERT_EQ(tester.Node(1)->binary_operator, xnn_binary_multiply); + ASSERT_EQ(tester.Node(1)->inputs[0], tester.Node(1)->inputs[1]); + break; + case xnn_node_type_unary_elementwise: + ASSERT_EQ(tester.Node(1)->unary_operator, xnn_unary_square); + break; + default: + XNN_UNREACHABLE; + } // Check that the output of convert is allocated in workspace. 
const xnn_value* convert_out = tester.Value(3); diff --git a/test/subgraph/subgraph-tester.cc b/test/subgraph/subgraph-tester.cc index edea0900211..192686b2e85 100644 --- a/test/subgraph/subgraph-tester.cc +++ b/test/subgraph/subgraph-tester.cc @@ -36,10 +36,22 @@ SubgraphTester::SubgraphTester(uint32_t external_value_ids, uint32_t flags) { subgraph_.reset(subgraph_ptr); } -SubgraphTester& SubgraphTester::AddInternalDynamicTensorF32( - const TensorShape& shape, uint32_t* id_out, uint32_t flags) { +SubgraphTester& SubgraphTester::AddInternalDynamicTensor( + const TensorShape& shape, enum xnn_datatype datatype, uint32_t* id_out, + uint32_t flags) { const xnn_status status = xnn_define_tensor_value( - subgraph_.get(), xnn_datatype_fp32, shape.Rank(), shape.Dims(), nullptr, + subgraph_.get(), datatype, shape.Rank(), shape.Dims(), nullptr, + XNN_INVALID_VALUE_ID, flags, id_out); + EXPECT_EQ(status, xnn_status_success); + + return *this; +} + +SubgraphTester& SubgraphTester::AddInternalStaticTensor( + const TensorShape& shape, enum xnn_datatype datatype, uint32_t* id_out, + const void* data, uint32_t flags) { + const xnn_status status = xnn_define_tensor_value( + subgraph_.get(), datatype, shape.Rank(), shape.Dims(), data, XNN_INVALID_VALUE_ID, flags, id_out); EXPECT_EQ(status, xnn_status_success); @@ -740,6 +752,17 @@ SubgraphTester& SubgraphTester::AddReduce( return *this; } +SubgraphTester& SubgraphTester::AddRMSNorm(uint32_t input_id, uint32_t scale_id, + uint32_t output_id, + enum xnn_norm_type norm_type, + float epsilon, uint32_t flags) { + const xnn_status status = + xnn_define_normalize(subgraph_.get(), norm_type, input_id, scale_id, + output_id, epsilon, flags); + EXPECT_EQ(status, xnn_status_success); + return *this; +} + SubgraphTester& SubgraphTester::AddSoftmax(uint32_t input_id, uint32_t output_id, uint32_t flags) { const xnn_status status = diff --git a/test/subgraph/subgraph-tester.h b/test/subgraph/subgraph-tester.h index 6cabcf98609..2341ff62ca0 100644 --- a/test/subgraph/subgraph-tester.h +++ b/test/subgraph/subgraph-tester.h @@ -120,9 +120,21 @@ class SubgraphTester { explicit SubgraphTester(uint32_t external_value_ids, uint32_t flags = xnn_test_runtime_flags()); + SubgraphTester& AddInternalStaticTensor(const TensorShape& shape, + enum xnn_datatype datatype, + uint32_t* id_out, const void* data, + uint32_t flags = 0); + + SubgraphTester& AddInternalDynamicTensor(const TensorShape& shape, + enum xnn_datatype datatype, + uint32_t* id_out, + uint32_t flags = 0); + SubgraphTester& AddInternalDynamicTensorF32(const TensorShape& shape, uint32_t* id_out, - uint32_t flags = 0); + uint32_t flags = 0) { + return AddInternalDynamicTensor(shape, xnn_datatype_fp32, id_out, flags); + } SubgraphTester& AddInternalDynamicallyQuantizedTensor( const TensorShape& shape, xnn_datatype datatype, size_t num_nonbatch_dims, @@ -497,6 +509,10 @@ class SubgraphTester { uint32_t input_id, uint32_t output_id, uint32_t flags = 0); + SubgraphTester& AddRMSNorm(uint32_t input_id, uint32_t scale_id, + uint32_t output_id, enum xnn_norm_type norm_type, + float epsilon, uint32_t flags = 0); + SubgraphTester& AddSoftmax(uint32_t input_id, uint32_t output_id, uint32_t flags = 0);
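For reference, a minimal sketch of how a test can drive the new AddRMSNorm helper declared above. This is illustrative only and not part of the diff; it assumes the test sits in namespace xnnpack and uses the same SubgraphTester overloads as test/subgraph/normalize.cc.

// Illustrative sketch: RMS-normalize a rank-3 fp32 input with no scale
// tensor, mirroring the no-scale branch of TestImpl in normalize.cc.
SubgraphTester tester(/*external_value_ids=*/2);
tester.AddInputTensor(/*rank=*/3, xnn_datatype_fp32, /*id=*/0)
    .AddOutputTensor(/*rank=*/3, xnn_datatype_fp32, /*id=*/1)
    .AddRMSNorm(/*input_id=*/0, /*scale_id=*/XNN_INVALID_VALUE_ID,
                /*output_id=*/1, xnn_norm_rms, /*epsilon=*/1.0e-6f);
if (tester.CreateRuntime() == xnn_status_unsupported_hardware) {
  GTEST_SKIP();  // As in TestImpl: skip on hardware without kernel support.
}

When a scale tensor is wanted, the scale value id is defined as an extra input of shape {channels} and passed in place of XNN_INVALID_VALUE_ID, as TestImpl does in the use_scale case.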