Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/gen/avx_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ SET(PROD_AVX_MICROKERNEL_SRCS
src/f32-igemm/gen/f32-igemm-1x16-minmax-avx-broadcast.c
src/f32-igemm/gen/f32-igemm-5x8-minmax-avx-broadcast.c
src/f32-igemm/gen/f32-igemm-5x16-minmax-avx-broadcast.c
src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c
src/f32-qc4w-gemm/gen/f32-qc4w-gemm-1x16-minmax-avx-broadcast.c
src/f32-qc4w-gemm/gen/f32-qc4w-gemm-3x16-minmax-avx-broadcast.c
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-1x16-minmax-avx-broadcast.c
Expand Down
1 change: 1 addition & 0 deletions gen/avx_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ PROD_AVX_MICROKERNEL_SRCS = [
"src/f32-igemm/gen/f32-igemm-1x16-minmax-avx-broadcast.c",
"src/f32-igemm/gen/f32-igemm-5x8-minmax-avx-broadcast.c",
"src/f32-igemm/gen/f32-igemm-5x16-minmax-avx-broadcast.c",
"src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c",
"src/f32-qc4w-gemm/gen/f32-qc4w-gemm-1x16-minmax-avx-broadcast.c",
"src/f32-qc4w-gemm/gen/f32-qc4w-gemm-3x16-minmax-avx-broadcast.c",
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-1x16-minmax-avx-broadcast.c",
Expand Down
1 change: 1 addition & 0 deletions scripts/generate-f32-maxpool.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
##################################### SIMD #####################################
tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=scalar -D SIMD_SIZE=1 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-scalar-u1.c &
tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=sse2 -D SIMD_SIZE=4 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-sse2-u4.c &
tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=avx -D SIMD_SIZE=8 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c &
tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=wasmsimd -D SIMD_SIZE=4 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-wasmsimd-u4.c &
tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=neon -D SIMD_SIZE=4 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-neon-u4.c &
tools/xngen src/f32-maxpool/maxpool.c.in -D DATATYPE=f32 -D ARCH=hvx -D SIMD_SIZE=32 -o src/f32-maxpool/gen/f32-maxpool-9p-minmax-hvx-u32.c &
Expand Down
8 changes: 7 additions & 1 deletion src/configs/maxpool-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ static void init_f32_maxpool_config(void) {
f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__neon_u4);
f32_maxpool_config.init.f32 = xnn_init_f32_minmax_scalar_params;
#elif XNN_ARCH_X86 || XNN_ARCH_X86_64
f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__sse2_u4);
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
if ((hardware_config->arch_flags & xnn_arch_x86_avx)) {
f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__avx_u8);
} else {
f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__sse2_u4);
}
f32_maxpool_config.init.f32 = xnn_init_f32_minmax_scalar_params;
#elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
f32_maxpool_config.ukernel = XNN_INIT_MAXPOOL_UKERNEL(xnn_f32_maxpool_minmax_ukernel_9p__wasmsimd_u4);
Expand Down
1 change: 1 addition & 0 deletions src/f32-maxpool/f32-maxpool-minmax.inc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#if XNN_ARCH_X86 || XNN_ARCH_X86_64
XNN_UKERNEL(xnn_arch_none, xnn_f32_maxpool_minmax_ukernel_9p__sse2_u4, 4, 9, float, struct xnn_f32_minmax_params, xnn_init_f32_minmax_scalar_params)
XNN_UKERNEL(xnn_arch_x86_avx, xnn_f32_maxpool_minmax_ukernel_9p__avx_u8, 8, 9, float, struct xnn_f32_minmax_params, xnn_init_f32_minmax_scalar_params)
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64

#if XNN_ARCH_ARM || XNN_ARCH_ARM64
Expand Down
216 changes: 216 additions & 0 deletions src/f32-maxpool/gen/f32-maxpool-9p-minmax-avx-u8.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// clang-format off
// Auto-generated file. Do not edit!
// Template: src/f32-maxpool/maxpool.c.in
// Generator: tools/xngen
//
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>


// NOTE(review): the remark below comes from the shared maxpool template's
// integer-datatype path and does not apply to this f32 instantiation — here
// the load/store hooks defined next are plain f32 loads and an identity
// pre-store. (For s8 on architectures that provide only max(u8, u8), the
// template emulates a signed max by XOR-ing operands with the sign-bit mask.)
#define xnn_load_impl(x) xnn_loadu_f32(x)
#define xnn_load_tail_impl(x, c) xnn_load_tail_f32(x, c)
#define xnn_load_tail_safe_impl(x, c) xnn_load_tail_safe_f32(x, c)
#define xnn_pre_store_impl(x) x

#include "src/xnnpack/common.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/simd/f32-avx.h"

// f32 max-pooling microkernel, AVX variant: 8 float lanes per SIMD vector,
// pooling window processed 9 input rows per pass ("9p").
//
// For each of `output_pixels` outputs, computes the per-channel maximum over
// `kernel_elements` input rows and clamps the result to
// [params->scalar.min, params->scalar.max]. Windows wider than 9 rows are
// handled in multiple passes: the first pass consumes up to 9 input pointers
// and writes a clamped partial result to `output`; each later pass reads the
// partial result back, maxes in up to 9 more rows, and re-applies only the
// upper clamp (the first pass's max(vacc, vmin) already established the
// lower bound, and taking further maxima cannot lower the value).
//
// Arguments:
//   output_pixels      - number of output pixels to process; must be nonzero
//   kernel_elements    - pooling window size, in input rows
//   channels           - channels per pixel; must be nonzero
//   input              - array of pointers to the input rows of each window;
//                        advanced by `input_increment` BYTES per output pixel
//   input_offset       - byte offset added to every input row pointer
//   input_pixel_stride - bytes added to input_offset after each output pixel
//   output             - output buffer; advanced by `output_increment` bytes
//                        per output pixel
//   params             - scalar min/max clamping bounds
void xnn_f32_maxpool_minmax_ukernel_9p__avx_u8(
size_t output_pixels,
size_t kernel_elements,
size_t channels,
const float** input,
size_t input_offset,
size_t input_pixel_stride,
float* output,
size_t input_increment,
size_t output_increment,
const struct xnn_f32_minmax_params* restrict params)
{
assert(output_pixels != 0);
assert(channels != 0);

// Broadcast the clamp bounds once outside all loops; XNN_FORCE_REALIZATION
// keeps the compiler from sinking/re-materializing them inside the loops.
const xnn_simd_f32_t vmin = xnn_set1_f32(params->scalar.min);
const xnn_simd_f32_t vmax = xnn_set1_f32(params->scalar.max);
XNN_FORCE_REALIZATION(vmin);
XNN_FORCE_REALIZATION(vmax);

do {
const float** i = (const float**) input;

// First pass: load the inputs, store the max pool in the output.
// Windows narrower than 9 rows reuse i0 for the missing slots — maxing a
// row with itself is a harmless no-op.
const float* i0 = *i++;
const float* i1 = 1 < kernel_elements ? *i++ : i0;
const float* i2 = 2 < kernel_elements ? *i++ : i0;
const float* i3 = 3 < kernel_elements ? *i++ : i0;
const float* i4 = 4 < kernel_elements ? *i++ : i0;
const float* i5 = 5 < kernel_elements ? *i++ : i0;
const float* i6 = 6 < kernel_elements ? *i++ : i0;
const float* i7 = 7 < kernel_elements ? *i++ : i0;
const float* i8 = 8 < kernel_elements ? *i++ : i0;
i0 = (const float*) ((uintptr_t) i0 + input_offset);
i1 = (const float*) ((uintptr_t) i1 + input_offset);
i2 = (const float*) ((uintptr_t) i2 + input_offset);
i3 = (const float*) ((uintptr_t) i3 + input_offset);
i4 = (const float*) ((uintptr_t) i4 + input_offset);
i5 = (const float*) ((uintptr_t) i5 + input_offset);
i6 = (const float*) ((uintptr_t) i6 + input_offset);
i7 = (const float*) ((uintptr_t) i7 + input_offset);
i8 = (const float*) ((uintptr_t) i8 + input_offset);

float* o = (float*) output;
size_t c = channels;
// Main channel loop: one full AVX vector (8 floats) per iteration.
for (; c >= 8; c -= 8) {
const xnn_simd_f32_t vi0 = xnn_load_impl(i0); i0 += 8;
const xnn_simd_f32_t vi1 = xnn_load_impl(i1); i1 += 8;
const xnn_simd_f32_t vi2 = xnn_load_impl(i2); i2 += 8;
const xnn_simd_f32_t vi3 = xnn_load_impl(i3); i3 += 8;
const xnn_simd_f32_t vi4 = xnn_load_impl(i4); i4 += 8;
const xnn_simd_f32_t vi5 = xnn_load_impl(i5); i5 += 8;
const xnn_simd_f32_t vi6 = xnn_load_impl(i6); i6 += 8;
const xnn_simd_f32_t vi7 = xnn_load_impl(i7); i7 += 8;
const xnn_simd_f32_t vi8 = xnn_load_impl(i8); i8 += 8;

// Balanced reduction tree over the 9 rows (4 levels deep instead of a
// serial 8-deep chain).
const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
xnn_simd_f32_t vacc = xnn_max_f32(vmax2345, vmax01678);

// Clamp to [min, max]; only the first pass applies the lower bound.
vacc = xnn_max_f32(vacc, vmin);
vacc = xnn_min_f32(vacc, vmax);

// Identity for f32 (see hook #defines above); integer template variants
// undo the sign-bit XOR here.
vacc = xnn_pre_store_impl(vacc);

xnn_storeu_f32(o, vacc); o += 8;
}
// Remainder: 1..7 trailing channels handled with partial loads/stores.
if (c > 0) {
const xnn_simd_f32_t vi0 = xnn_load_tail_impl(i0, c);
const xnn_simd_f32_t vi1 = xnn_load_tail_impl(i1, c);
const xnn_simd_f32_t vi2 = xnn_load_tail_impl(i2, c);
const xnn_simd_f32_t vi3 = xnn_load_tail_impl(i3, c);
const xnn_simd_f32_t vi4 = xnn_load_tail_impl(i4, c);
const xnn_simd_f32_t vi5 = xnn_load_tail_impl(i5, c);
const xnn_simd_f32_t vi6 = xnn_load_tail_impl(i6, c);
const xnn_simd_f32_t vi7 = xnn_load_tail_impl(i7, c);
const xnn_simd_f32_t vi8 = xnn_load_tail_impl(i8, c);

const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
xnn_simd_f32_t vacc = xnn_max_f32(vmax2345, vmax01678);

vacc = xnn_max_f32(vacc, vmin);
vacc = xnn_min_f32(vacc, vmax);

vacc = xnn_pre_store_impl(vacc);

xnn_store_tail_f32(o, vacc, c); o += c;
}

// Passes 1 - n: Max more inputs to the output.
// NOTE(review): this assignment to the outer `o` is dead — the pass loop
// below declares its own shadowing `float* o`. Harmless, but could be
// dropped in the template.
o = (float*) output;
// k counts the rows still unprocessed after the first 9; signed ptrdiff_t
// so the loop is skipped entirely when kernel_elements <= 9.
for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 9) {
const float* i0 = *i++;
const float* i1 = 1 < k ? *i++ : i0;
const float* i2 = 2 < k ? *i++ : i0;
const float* i3 = 3 < k ? *i++ : i0;
const float* i4 = 4 < k ? *i++ : i0;
const float* i5 = 5 < k ? *i++ : i0;
const float* i6 = 6 < k ? *i++ : i0;
const float* i7 = 7 < k ? *i++ : i0;
const float* i8 = 8 < k ? *i++ : i0;
i0 = (const float*) ((uintptr_t) i0 + input_offset);
i1 = (const float*) ((uintptr_t) i1 + input_offset);
i2 = (const float*) ((uintptr_t) i2 + input_offset);
i3 = (const float*) ((uintptr_t) i3 + input_offset);
i4 = (const float*) ((uintptr_t) i4 + input_offset);
i5 = (const float*) ((uintptr_t) i5 + input_offset);
i6 = (const float*) ((uintptr_t) i6 + input_offset);
i7 = (const float*) ((uintptr_t) i7 + input_offset);
i8 = (const float*) ((uintptr_t) i8 + input_offset);

float* o = (float*) output;
size_t c = channels;
for (; c >= 8; c -= 8) {
const xnn_simd_f32_t vi0 = xnn_load_impl(i0); i0 += 8;
const xnn_simd_f32_t vi1 = xnn_load_impl(i1); i1 += 8;
const xnn_simd_f32_t vi2 = xnn_load_impl(i2); i2 += 8;
const xnn_simd_f32_t vi3 = xnn_load_impl(i3); i3 += 8;
const xnn_simd_f32_t vi4 = xnn_load_impl(i4); i4 += 8;
const xnn_simd_f32_t vi5 = xnn_load_impl(i5); i5 += 8;
const xnn_simd_f32_t vi6 = xnn_load_impl(i6); i6 += 8;
const xnn_simd_f32_t vi7 = xnn_load_impl(i7); i7 += 8;
const xnn_simd_f32_t vi8 = xnn_load_impl(i8); i8 += 8;
// Partial result from the previous pass(es), read back from the output.
const xnn_simd_f32_t vprev = xnn_load_impl(o);

const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
const xnn_simd_f32_t vmax012345678 = xnn_max_f32(vmax2345, vmax01678);

xnn_simd_f32_t vacc = xnn_max_f32(vprev, vmax012345678);

// Upper clamp only; the lower bound was applied in the first pass and is
// preserved by taking further maxima.
vacc = xnn_min_f32(vacc, vmax);

vacc = xnn_pre_store_impl(vacc);

xnn_storeu_f32(o, vacc); o += 8;
}
if (c > 0) {
const xnn_simd_f32_t vi0 = xnn_load_tail_impl(i0, c);
const xnn_simd_f32_t vi1 = xnn_load_tail_impl(i1, c);
const xnn_simd_f32_t vi2 = xnn_load_tail_impl(i2, c);
const xnn_simd_f32_t vi3 = xnn_load_tail_impl(i3, c);
const xnn_simd_f32_t vi4 = xnn_load_tail_impl(i4, c);
const xnn_simd_f32_t vi5 = xnn_load_tail_impl(i5, c);
const xnn_simd_f32_t vi6 = xnn_load_tail_impl(i6, c);
const xnn_simd_f32_t vi7 = xnn_load_tail_impl(i7, c);
const xnn_simd_f32_t vi8 = xnn_load_tail_impl(i8, c);
// NOTE(review): presumably the `_safe` tail-load variant avoids reading
// past the end of the output buffer here (unlike the input rows, which
// use the plain tail load) — confirm against src/xnnpack/simd/f32-avx.h.
const xnn_simd_f32_t vprev = xnn_load_tail_safe_impl(o, c);

const xnn_simd_f32_t vmax018 = xnn_max_f32(xnn_max_f32(vi0, vi1), vi8);
const xnn_simd_f32_t vmax23 = xnn_max_f32(vi2, vi3);
const xnn_simd_f32_t vmax45 = xnn_max_f32(vi4, vi5);
const xnn_simd_f32_t vmax67 = xnn_max_f32(vi6, vi7);

const xnn_simd_f32_t vmax2345 = xnn_max_f32(vmax23, vmax45);
const xnn_simd_f32_t vmax01678 = xnn_max_f32(vmax018, vmax67);
const xnn_simd_f32_t vmax012345678 = xnn_max_f32(vmax2345, vmax01678);

xnn_simd_f32_t vacc = xnn_max_f32(vprev, vmax012345678);

vacc = xnn_min_f32(vacc, vmax);

vacc = xnn_pre_store_impl(vacc);

xnn_store_tail_f32(o, vacc, c);
}
}

// Advance to the next output pixel: bump the pointer array, the byte offset
// applied to each row pointer, and the output cursor.
input = (const float**) ((uintptr_t) input + input_increment);
input_offset += input_pixel_stride;
output = (float*) ((uintptr_t) output + output_increment);
} while (--output_pixels != 0);
}
Loading