diff --git a/cmake/gen/avx_microkernels.cmake b/cmake/gen/avx_microkernels.cmake index 9961ad9ec1e..f1930ae09f8 100644 --- a/cmake/gen/avx_microkernels.cmake +++ b/cmake/gen/avx_microkernels.cmake @@ -25,6 +25,7 @@ SET(PROD_AVX_MICROKERNEL_SRCS src/f32-igemm/gen/f32-igemm-1x16-minmax-avx-broadcast.c src/f32-igemm/gen/f32-igemm-5x8-minmax-avx-broadcast.c src/f32-igemm/gen/f32-igemm-5x16-minmax-avx-broadcast.c + src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c src/f32-qc4w-gemm/gen/f32-qc4w-gemm-1x16-minmax-avx-broadcast.c src/f32-qc4w-gemm/gen/f32-qc4w-gemm-3x16-minmax-avx-broadcast.c src/f32-qc8w-gemm/gen/f32-qc8w-gemm-1x16-minmax-avx-broadcast.c diff --git a/gen/avx_microkernels.bzl b/gen/avx_microkernels.bzl index 399737f78e3..68f36498dc9 100644 --- a/gen/avx_microkernels.bzl +++ b/gen/avx_microkernels.bzl @@ -21,6 +21,7 @@ PROD_AVX_MICROKERNEL_SRCS = [ "src/f32-igemm/gen/f32-igemm-1x16-minmax-avx-broadcast.c", "src/f32-igemm/gen/f32-igemm-5x8-minmax-avx-broadcast.c", "src/f32-igemm/gen/f32-igemm-5x16-minmax-avx-broadcast.c", + "src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c", "src/f32-qc4w-gemm/gen/f32-qc4w-gemm-1x16-minmax-avx-broadcast.c", "src/f32-qc4w-gemm/gen/f32-qc4w-gemm-3x16-minmax-avx-broadcast.c", "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-1x16-minmax-avx-broadcast.c", diff --git a/scripts/generate-f32-maxpool.sh b/scripts/generate-f32-maxpool.sh index 2a142c77987..4fbb717fa83 100755 --- a/scripts/generate-f32-maxpool.sh +++ b/scripts/generate-f32-maxpool.sh @@ -7,6 +7,7 @@ ##################################### SIMD ##################################### tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=scalar -D SIMD_SIZE=1 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-scalar-u1.c & tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=sse2 -D SIMD_SIZE=4 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-sse2-u4.c & +tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=avx -D SIMD_SIZE=8 -o 
src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c & tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=wasmsimd -D SIMD_SIZE=4 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-wasmsimd-u4.c & tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=neon -D SIMD_SIZE=4 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-neon-u4.c & tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=hvx -D SIMD_SIZE=32 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-hvx-u32.c & diff --git a/src/configs/maxpool-config.c b/src/configs/maxpool-config.c index 4ff7d8fe0eb..b304ef7fc3d 100644 --- a/src/configs/maxpool-config.c +++ b/src/configs/maxpool-config.c @@ -74,7 +74,13 @@ static void init_f32_maxpool_config(void) { f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__neon_u4); f32_maxpool_config.init.f32 = xnn_init_f32_minmax_scalar_params; #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 - f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__sse2_u4); + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if ((hardware_config->arch_flags & xnn_arch_x86_avx)) { + f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__avx_u8); + } else { + f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__sse2_u4); + } f32_maxpool_config.init.f32 = xnn_init_f32_minmax_scalar_params; #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__wasmsimd_u4); diff --git a/src/f32-maxpool/f32-maxpool-minmax.inc b/src/f32-maxpool/f32-maxpool-minmax.inc index f6e50a29933..b4e3a246206 100644 --- a/src/f32-maxpool/f32-maxpool-minmax.inc +++ b/src/f32-maxpool/f32-maxpool-minmax.inc @@ -8,6 +8,7 @@ #if XNN_ARCH_X86 || XNN_ARCH_X86_64 XNN_UKERNEL(xnn_arch_none, xnn_f32_maxpool_minmax_ukernel_9p__sse2_u4, 4, 9, 
 float, struct xnn_f32_minmax_params, xnn_init_f32_minmax_scalar_params)
+XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_maxpool_minmax_ukernel_9p__avx_u8, 8, 9, float, struct xnn_f32_minmax_params, xnn_init_f32_minmax_scalar_params)
 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64

 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
diff --git a/src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c b/src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c
new file mode 100644
index 00000000000..7c291d0a895
--- /dev/null
+++ b/src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c
@@ -0,0 +1,216 @@
+// clang-format off
+// Auto-generated file. Do not edit!
+// Template: src/f32-maxpool/maxpool.c.in
+// Generator: tools/xngen
+//
+// Copyright 2025 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+
+// On some architectures, we have max(u8, u8) but not max(s8, s8). We can emulate max(s8, s8) on these architectures by
+// xoring with the sign bit mask.
// Template shims: the shared maxpool template loads/stores through these
// hooks. For f32 they map directly onto the f32 SIMD helpers, and no
// transformation is needed before storing, so xnn_pre_store_impl is the
// identity (the s8/u8 sign-bit-xor trick mentioned above does not apply).
#define xnn_load_impl(x) xnn_loadu_f32(x)
#define xnn_load_tail_impl(x, c) xnn_load_tail_f32(x, c)
#define xnn_load_tail_safe_impl(x, c) xnn_load_tail_safe_f32(x, c)
#define xnn_pre_store_impl(x) x

#include "src/xnnpack/common.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/simd/f32-avx.h"

// f32 max-pooling microkernel, AVX build ("u8": xnn_simd_f32_t holds 8
// floats). Pools up to 9 kernel elements per pass ("9p"); kernels with more
// than 9 elements take additional passes below that max-accumulate into the
// output written by the first pass. Every result is clamped to
// [params->scalar.min, params->scalar.max].
//
//   output_pixels      - number of output pixels to compute; must be != 0.
//   kernel_elements    - number of pooling elements per output pixel.
//   channels           - channels per pixel; must be != 0.
//   input              - indirection buffer: per-pixel array of input row
//                        pointers, advanced by input_increment bytes per
//                        output pixel.
//   input_offset       - byte offset added to every input row pointer.
//   input_pixel_stride - bytes added to input_offset after each output pixel.
//   output             - output pointer, advanced by output_increment bytes
//                        per output pixel.
//   params             - min/max clamping parameters.
void xnn_f32_maxpool_minmax_ukernel_9p__avx_u8(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    size_t input_pixel_stride,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const struct xnn_f32_minmax_params* restrict params)
{
  assert(output_pixels != 0);
  assert(channels != 0);

  // Broadcast the clamping bounds once, outside the pixel loop.
  const xnn_simd_f32_t vmin = xnn_set1_f32(params->scalar.min);
  const xnn_simd_f32_t vmax = xnn_set1_f32(params->scalar.max);
  XNN_FORCE_REALIZATION(vmin);
  XNN_FORCE_REALIZATION(vmax);

  do {
    const float** i = (const float**) input;

    // First pass: load the inputs, store the max pool in the output.
    // If kernel_elements < 9, the surplus pointers alias i0, so the
    // duplicated rows cannot change the max.
    const float* i0 = *i++;
    const float* i1 = 1 < kernel_elements ? *i++ : i0;
    const float* i2 = 2 < kernel_elements ? *i++ : i0;
    const float* i3 = 3 < kernel_elements ? *i++ : i0;
    const float* i4 = 4 < kernel_elements ? *i++ : i0;
    const float* i5 = 5 < kernel_elements ? *i++ : i0;
    const float* i6 = 6 < kernel_elements ? *i++ : i0;
    const float* i7 = 7 < kernel_elements ? *i++ : i0;
    const float* i8 = 8 < kernel_elements ? *i++ : i0;
    // Apply the common byte offset to every row pointer.
    i0 = (const float*) ((uintptr_t) i0 + input_offset);
    i1 = (const float*) ((uintptr_t) i1 + input_offset);
    i2 = (const float*) ((uintptr_t) i2 + input_offset);
    i3 = (const float*) ((uintptr_t) i3 + input_offset);
    i4 = (const float*) ((uintptr_t) i4 + input_offset);
    i5 = (const float*) ((uintptr_t) i5 + input_offset);
    i6 = (const float*) ((uintptr_t) i6 + input_offset);
    i7 = (const float*) ((uintptr_t) i7 + input_offset);
    i8 = (const float*) ((uintptr_t) i8 + input_offset);

    float* o = (float*) output;
    size_t c = channels;
    // Main loop: 8 channels per iteration; reduce the 9 rows with a tree of
    // pairwise maxes, clamp to [min, max], and store.
    for (; c >= 8; c -= 8) {
      const xnn_simd_f32_t vi0 = xnn_load_impl(i0); i0 += 8;
      const xnn_simd_f32_t vi1 = xnn_load_impl(i1); i1 += 8;
      const xnn_simd_f32_t vi2 = xnn_load_impl(i2); i2 += 8;
      const xnn_simd_f32_t vi3 = xnn_load_impl(i3); i3 += 8;
      const xnn_simd_f32_t vi4 = xnn_load_impl(i4); i4 += 8;
      const xnn_simd_f32_t vi5 = xnn_load_impl(i5); i5 += 8;
      const xnn_simd_f32_t vi6 = xnn_load_impl(i6); i6 += 8;
      const xnn_simd_f32_t vi7 = xnn_load_impl(i7); i7 += 8;
      const xnn_simd_f32_t vi8 = xnn_load_impl(i8); i8 += 8;

      const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
      const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
      const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
      const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

      const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
      const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
      xnn_simd_f32_t vacc = xnn_max_f32(vmax2345, vmax01678);

      vacc = xnn_max_f32(vacc, vmin);
      vacc = xnn_min_f32(vacc, vmax);

      vacc = xnn_pre_store_impl(vacc);

      xnn_storeu_f32(o, vacc); o += 8;
    }
    // Channel remainder (0 < c < 8): partial-width loads/stores.
    if (c > 0) {
      const xnn_simd_f32_t vi0 = xnn_load_tail_impl(i0, c);
      const xnn_simd_f32_t vi1 = xnn_load_tail_impl(i1, c);
      const xnn_simd_f32_t vi2 = xnn_load_tail_impl(i2, c);
      const xnn_simd_f32_t vi3 = xnn_load_tail_impl(i3, c);
      const xnn_simd_f32_t vi4 = xnn_load_tail_impl(i4, c);
      const xnn_simd_f32_t vi5 = xnn_load_tail_impl(i5, c);
      const xnn_simd_f32_t vi6 = xnn_load_tail_impl(i6, c);
      const xnn_simd_f32_t vi7 = xnn_load_tail_impl(i7, c);
      const xnn_simd_f32_t vi8 = xnn_load_tail_impl(i8, c);

      const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
      const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
      const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
      const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

      const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
      const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
      xnn_simd_f32_t vacc = xnn_max_f32(vmax2345, vmax01678);

      vacc = xnn_max_f32(vacc, vmin);
      vacc = xnn_min_f32(vacc, vmax);

      vacc = xnn_pre_store_impl(vacc);

      xnn_store_tail_f32(o, vacc, c); o += c;
    }

    // Passes 1 - n: Max more inputs to the output.
    // (This assignment to the outer `o` is shadowed by the per-pass `o`
    // declared inside the loop below; kept as generated.)
    o = (float*) output;
    // k counts the kernel elements still unprocessed in this pass; the same
    // alias-to-i0 trick as above pads a short final pass.
    for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 9) {
      const float* i0 = *i++;
      const float* i1 = 1 < k ? *i++ : i0;
      const float* i2 = 2 < k ? *i++ : i0;
      const float* i3 = 3 < k ? *i++ : i0;
      const float* i4 = 4 < k ? *i++ : i0;
      const float* i5 = 5 < k ? *i++ : i0;
      const float* i6 = 6 < k ? *i++ : i0;
      const float* i7 = 7 < k ? *i++ : i0;
      const float* i8 = 8 < k ? *i++ : i0;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      i8 = (const float*) ((uintptr_t) i8 + input_offset);

      float* o = (float*) output;
      size_t c = channels;
      for (; c >= 8; c -= 8) {
        const xnn_simd_f32_t vi0 = xnn_load_impl(i0); i0 += 8;
        const xnn_simd_f32_t vi1 = xnn_load_impl(i1); i1 += 8;
        const xnn_simd_f32_t vi2 = xnn_load_impl(i2); i2 += 8;
        const xnn_simd_f32_t vi3 = xnn_load_impl(i3); i3 += 8;
        const xnn_simd_f32_t vi4 = xnn_load_impl(i4); i4 += 8;
        const xnn_simd_f32_t vi5 = xnn_load_impl(i5); i5 += 8;
        const xnn_simd_f32_t vi6 = xnn_load_impl(i6); i6 += 8;
        const xnn_simd_f32_t vi7 = xnn_load_impl(i7); i7 += 8;
        const xnn_simd_f32_t vi8 = xnn_load_impl(i8); i8 += 8;
        // Partial result from the previous pass(es).
        const xnn_simd_f32_t vprev = xnn_load_impl(o);

        const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
        const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
        const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
        const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

        const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
        const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
        const xnn_simd_f32_t vmax012345678 = xnn_max_f32(vmax2345, vmax01678);

        xnn_simd_f32_t vacc = xnn_max_f32(vprev, vmax012345678);

        // Only the upper clamp is needed here: vprev was already clamped to
        // >= min in the first pass, and max() cannot go below vprev.
        vacc = xnn_min_f32(vacc, vmax);

        vacc = xnn_pre_store_impl(vacc);

        xnn_storeu_f32(o, vacc); o += 8;
      }
      if (c > 0) {
        const xnn_simd_f32_t vi0 = xnn_load_tail_impl(i0, c);
        const xnn_simd_f32_t vi1 = xnn_load_tail_impl(i1, c);
        const xnn_simd_f32_t vi2 = xnn_load_tail_impl(i2, c);
        const xnn_simd_f32_t vi3 = xnn_load_tail_impl(i3, c);
        const xnn_simd_f32_t vi4 = xnn_load_tail_impl(i4, c);
        const xnn_simd_f32_t vi5 = xnn_load_tail_impl(i5, c);
        const xnn_simd_f32_t vi6 = xnn_load_tail_impl(i6, c);
        const xnn_simd_f32_t vi7 = xnn_load_tail_impl(i7, c);
        const xnn_simd_f32_t vi8 = xnn_load_tail_impl(i8, c);
        // NOTE(review): the "safe" tail load is used for the output buffer,
        // presumably to avoid reading past its end — confirm against the
        // simd header.
        const xnn_simd_f32_t vprev = xnn_load_tail_safe_impl(o, c);

        const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
        const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
        const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
        const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

        const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
        const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
        const xnn_simd_f32_t vmax012345678 = xnn_max_f32(vmax2345, vmax01678);

        xnn_simd_f32_t vacc = xnn_max_f32(vprev, vmax012345678);

        vacc = xnn_min_f32(vacc, vmax);

        vacc = xnn_pre_store_impl(vacc);

        xnn_store_tail_f32(o, vacc, c);
      }
    }

    // Advance to the next output pixel.
    input = (const float**) ((uintptr_t) input + input_increment);
    input_offset += input_pixel_stride;
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}