Skip to content

Commit a9dd5b8

Browse files
gonnet authored and xnnpack-bot committed
Add rms-norm op for fp32 and fp16.
PiperOrigin-RevId: 798064116
1 parent b79094f commit a9dd5b8

File tree

97 files changed

+20449
-107
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+20449
-107
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ SET(OPERATOR_SRCS
501501
src/operators/dynamic-fully-connected-nc.c
502502
src/operators/fully-connected-nc.c
503503
src/operators/max-pooling-nhwc.c
504+
src/operators/normalize-nc.c
504505
src/operators/pack-lh.c
505506
src/operators/reduce-nd.c
506507
src/operators/resize-bilinear-nchw.c
@@ -535,6 +536,7 @@ SET(SUBGRAPH_SRCS
535536
src/subgraph/fully-connected-sparse.c
536537
src/subgraph/fully-connected.c
537538
src/subgraph/max-pooling-2d.c
539+
src/subgraph/normalize.c
538540
src/subgraph/pack-lh.c
539541
src/subgraph/reshape-helpers.c
540542
src/subgraph/rope.c

bench/BUILD.bazel

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ xnnpack_cc_library(
9797

9898
xnnpack_cxx_library(
9999
name = "gemm_benchmark",
100+
testonly = True,
100101
srcs = [
101102
"gemm-benchmark.cc",
102103
],
@@ -115,6 +116,7 @@ xnnpack_cxx_library(
115116

116117
xnnpack_cxx_library(
117118
name = "packw_benchmark",
119+
testonly = True,
118120
hdrs = [
119121
"packw-benchmark.h",
120122
],
@@ -126,6 +128,7 @@ xnnpack_cxx_library(
126128

127129
xnnpack_cxx_library(
128130
name = "bgemm",
131+
testonly = True,
129132
hdrs = [
130133
"bgemm.h",
131134
],
@@ -134,6 +137,19 @@ xnnpack_cxx_library(
134137
],
135138
)
136139

140+
xnnpack_cxx_library(
141+
name = "packq_benchmark",
142+
testonly = True,
143+
srcs = [
144+
"packq-benchmark.cc",
145+
],
146+
hdrs = ["packq-benchmark.h"],
147+
deps = MICROKERNEL_BENCHMARK_DEPS + [
148+
":bgemm",
149+
"@com_google_benchmark//:benchmark",
150+
],
151+
)
152+
137153
######################### Benchmarks for micro-kernels #########################
138154

139155
[xnnpack_benchmark(
@@ -275,8 +291,10 @@ xnnpack_benchmark(
275291
"f32_vcmul",
276292
"rdminmax",
277293
"rdsum",
294+
"rdsum2",
278295
"rminmax",
279296
"rsum",
297+
"rsum2",
280298
"x8_lut",
281299
]]
282300

@@ -453,18 +471,6 @@ xnnpack_benchmark(
453471
],
454472
)
455473

456-
xnnpack_cxx_library(
457-
name = "packq_benchmark",
458-
srcs = [
459-
"packq-benchmark.cc",
460-
],
461-
hdrs = ["packq-benchmark.h"],
462-
deps = MICROKERNEL_BENCHMARK_DEPS + [
463-
":bgemm",
464-
"@com_google_benchmark//:benchmark",
465-
],
466-
)
467-
468474
xnnpack_benchmark(
469475
name = "x8_packq_bench",
470476
srcs = [

bench/rdsum2.cc

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <random>
9+
10+
#include "bench/utils.h"
11+
#include "src/xnnpack/buffer.h"
12+
#include "src/xnnpack/common.h"
13+
#include "src/xnnpack/hardware-config.h" // IWYU pragma: keep
14+
#include "src/xnnpack/reduce.h" // IWYU pragma: keep
15+
#include <benchmark/benchmark.h>
16+
17+
// Microkernel function, templated on the `params` type.
18+
template <typename Input, typename Output, typename UKernelParams>
19+
using UKernelFn = void (*)(size_t, size_t, size_t, size_t, const Input*, size_t,
20+
size_t, size_t, const Input*, Output*,
21+
const UKernelParams*);
22+
23+
template <typename Input, typename Output, typename UKernelParams>
24+
static void reduce(benchmark::State& state, uint64_t arch_flags,
25+
UKernelFn<Input, Output, UKernelParams> ukernel) {
26+
if (!benchmark::utils::CheckArchFlags(state, arch_flags)) {
27+
return;
28+
}
29+
30+
const size_t channels = state.range(0);
31+
const size_t rows = state.range(1);
32+
33+
std::random_device random_device;
34+
auto rng = std::mt19937(random_device());
35+
36+
xnnpack::Buffer<Input, XNN_ALLOCATION_ALIGNMENT> input(
37+
channels * rows, xnnpack::XnnExtraBytes);
38+
xnnpack::Buffer<Input, XNN_ALLOCATION_ALIGNMENT> zero(channels, 0,
39+
xnnpack::XnnExtraBytes);
40+
xnnpack::fill_uniform_random_bits(input.data(), input.size(), rng);
41+
xnnpack::Buffer<Output, XNN_ALLOCATION_ALIGNMENT> output(channels);
42+
43+
UKernelParams params;
44+
memset(&params, 0, sizeof(params));
45+
46+
for (auto _ : state) {
47+
ukernel(channels, rows, 1, 1, input.data(), channels * sizeof(Input), 0, 0,
48+
zero.data(), output.data(), &params);
49+
}
50+
51+
const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
52+
if (cpu_frequency != 0) {
53+
state.counters["cpufreq"] = cpu_frequency;
54+
}
55+
56+
const size_t elements_per_iteration = channels * rows;
57+
state.counters["elements"] = benchmark::Counter(
58+
static_cast<uint64_t>(state.iterations()) * elements_per_iteration,
59+
benchmark::Counter::kIsRate);
60+
61+
const size_t bytes_per_iteration = channels * rows * sizeof(Input);
62+
state.counters["bytes"] = benchmark::Counter(
63+
static_cast<uint64_t>(state.iterations()) * bytes_per_iteration,
64+
benchmark::Counter::kIsRate);
65+
}
66+
67+
#define XNN_UKERNEL(arch_flags, ukernel, row_tile, batch_tile, vector_tile, \
68+
datatype_in, datatype_out, params_type, init_params) \
69+
BENCHMARK_CAPTURE(reduce, ukernel, arch_flags, ukernel) \
70+
->Apply(benchmark::utils::ReduceDiscontiguousParameters<datatype_in>) \
71+
->UseRealTime();
72+
// #include "src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc"
73+
#include "src/f32-rdsum2/f32-rdsum2.inc"  // rdsum2 (not rdsum) kernel table
74+
#undef XNN_UKERNEL
75+
76+
#ifndef XNNPACK_BENCHMARK_NO_MAIN
77+
XNN_BENCHMARK_MAIN();
78+
#endif

bench/rsum2.cc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <random>
9+
10+
#include "bench/utils.h"
11+
#include "src/xnnpack/buffer.h"
12+
#include "src/xnnpack/common.h"
13+
#include "src/xnnpack/hardware-config.h" // IWYU pragma: keep
14+
#include "src/xnnpack/reduce.h" // IWYU pragma: keep
15+
#include <benchmark/benchmark.h>
16+
17+
// Microkernel function, templated on the `params` type.
18+
template <typename Input, typename Output, typename UKernelParams>
19+
using UKernelFn = void (*)(size_t, const Input*, Output*, const UKernelParams*);
20+
21+
template <typename Input, typename Output, typename UKernelParams>
22+
static void reduce(benchmark::State& state, uint64_t arch_flags,
23+
UKernelFn<Input, Output, UKernelParams> ukernel) {
24+
if (!benchmark::utils::CheckArchFlags(state, arch_flags)) {
25+
return;
26+
}
27+
28+
const size_t channels = state.range(0);
29+
const size_t rows = state.range(1);
30+
31+
std::random_device random_device;
32+
auto rng = std::mt19937(random_device());
33+
34+
xnnpack::Buffer<Input, XNN_ALLOCATION_ALIGNMENT> input(
35+
channels * rows, xnnpack::XnnExtraBytes);
36+
xnnpack::fill_uniform_random_bits(input.data(), input.size(), rng);
37+
38+
UKernelParams params;
39+
memset(&params, 0, sizeof(params));
40+
41+
Output output = 0;
42+
for (auto _ : state) {
43+
for (size_t r = 0; r < rows; ++r) {
44+
ukernel(channels * sizeof(Input), input.data() + r * channels, &output,
45+
&params);
46+
}
47+
}
48+
49+
const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
50+
if (cpu_frequency != 0) {
51+
state.counters["cpufreq"] = cpu_frequency;
52+
}
53+
54+
const size_t elements_per_iteration = rows * channels;
55+
state.counters["elements"] = benchmark::Counter(
56+
static_cast<uint64_t>(state.iterations()) * elements_per_iteration,
57+
benchmark::Counter::kIsRate);
58+
59+
const size_t bytes_per_iteration = rows * channels * sizeof(Input);
60+
state.counters["bytes"] = benchmark::Counter(
61+
static_cast<uint64_t>(state.iterations()) * bytes_per_iteration,
62+
benchmark::Counter::kIsRate);
63+
}
64+
65+
#define XNN_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, datatype_in, \
66+
datatype_out, params_type, init_params) \
67+
BENCHMARK_CAPTURE(reduce, ukernel, arch_flags, ukernel) \
68+
->Apply(benchmark::utils::ReduceParameters<datatype_in>) \
69+
->UseRealTime();
70+
#include "src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc"
71+
#include "src/f32-rsum2/f32-rsum2.inc"
72+
#undef XNN_UKERNEL
73+
74+
#ifndef XNNPACK_BENCHMARK_NO_MAIN
75+
XNN_BENCHMARK_MAIN();
76+
#endif

bench/subgraph/fp32-l2-norm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ xnn_subgraph_t FP32L2Norm(size_t m, size_t n, size_t k, uint32_t norm_mask) {
9393

9494
uint32_t inv_sqrt_sum_sq = XNN_INVALID_VALUE_ID;
9595
status = xnn_define_tensor_value(
96-
subgraph, xnn_datatype_fp32, dims.size(), dims.data(),
96+
subgraph, xnn_datatype_fp32, reduction_dims.size(), reduction_dims.data(),
9797
/*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &inv_sqrt_sum_sq);
9898
if (status != xnn_status_success) {
9999
std::cerr << "failed to create tensor inv_sqrt_sum_sq" << std::endl;

bench/utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <algorithm>
1010
#include <cstddef>
1111
#include <cstdint>
12-
#include <functional>
1312
#include <string>
1413

1514
#include "src/xnnpack/common.h"

build_srcs.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ OPERATOR_SRCS = [
2525
"src/operators/dynamic-fully-connected-nc.c",
2626
"src/operators/fully-connected-nc.c",
2727
"src/operators/max-pooling-nhwc.c",
28+
"src/operators/normalize-nc.c",
2829
"src/operators/pack-lh.c",
2930
"src/operators/reduce-nd.c",
3031
"src/operators/resize-bilinear-nchw.c",
@@ -56,6 +57,7 @@ SUBGRAPH_SRCS = [
5657
"src/subgraph/fully-connected-sparse.c",
5758
"src/subgraph/fully-connected.c",
5859
"src/subgraph/max-pooling-2d.c",
60+
"src/subgraph/normalize.c",
5961
"src/subgraph/pack-lh.c",
6062
"src/subgraph/reshape-helpers.c",
6163
"src/subgraph/rope.c",
@@ -99,7 +101,9 @@ MICROKERNEL_DEFS = [
99101
"src/f16-dwconv/f16-dwconv-minmax.inc",
100102
"src/f16-f32-vcvt/f16-f32-vcvt.inc",
101103
"src/f16-f32acc-rdsum/f16-f32acc-rdsum.inc",
104+
"src/f16-f32acc-rdsum2/f16-f32acc-rdsum2.inc",
102105
"src/f16-f32acc-rsum/f16-f32acc-rsum.inc",
106+
"src/f16-f32acc-rsum2/f16-f32acc-rsum2.inc",
103107
"src/f16-maxpool/f16-maxpool-minmax.inc",
104108
"src/f16-qs8-vcvt/f16-qs8-vcvt.inc",
105109
"src/f16-qu8-vcvt/f16-qu8-vcvt.inc",
@@ -163,10 +167,12 @@ MICROKERNEL_DEFS = [
163167
"src/f32-rdminmax/f32-rdmax.inc",
164168
"src/f32-rdminmax/f32-rdmin.inc",
165169
"src/f32-rdsum/f32-rdsum.inc",
170+
"src/f32-rdsum2/f32-rdsum2.inc",
166171
"src/f32-rminmax/f32-rmax.inc",
167172
"src/f32-rminmax/f32-rmin.inc",
168173
"src/f32-rminmax/f32-rminmax.inc",
169174
"src/f32-rsum/f32-rsum.inc",
175+
"src/f32-rsum2/f32-rsum2.inc",
170176
"src/f32-spmm/f32-spmm-minmax.inc",
171177
"src/f32-vabs/f32-vabs.inc",
172178
"src/f32-vapproxgelu/f32-vapproxgelu.inc",

cmake/gen/avx512f_microkernels.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
2727
src/f32-rdminmax/gen/f32-rdmax-2p2x-avx512f-u32.c
2828
src/f32-rdminmax/gen/f32-rdmin-2p2x-avx512f-u32.c
2929
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-u64.c
30+
src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx512f.c
3031
src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c
3132
src/f32-rminmax/gen/f32-rmin-avx512f-u64-acc4.c
3233
src/f32-rminmax/gen/f32-rminmax-avx512f-u64-acc4.c
3334
src/f32-rsum/gen/f32-rsum-avx512f-u32-acc2.c
35+
src/f32-rsum2/gen/f32-rsum2-avx512f-u16.c
3436
src/f32-vapproxgelu/gen/f32-vapproxgelu-avx512f-rational-12-10-div.c
3537
src/f32-vbinary/gen/f32-vadd-avx512f-u32.c
3638
src/f32-vbinary/gen/f32-vaddc-avx512f-u32.c

cmake/gen/avx512skx_microkernels.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
SET(PROD_AVX512SKX_MICROKERNEL_SRCS
1313
src/f16-f32-vcvt/gen/f16-f32-vcvt-avx512skx-u16.c
1414
src/f16-f32acc-rdsum/gen/f16-f32acc-rdsum-7p7x-avx512skx-u64.c
15+
src/f16-f32acc-rdsum2/gen/f16-f32acc-rdsum2-7p7x-avx512skx.c
1516
src/f16-f32acc-rsum/gen/f16-f32acc-rsum-avx512skx-u32-acc2.c
17+
src/f16-f32acc-rsum2/gen/f16-f32acc-rsum2-avx512skx.c
1618
src/f16-rminmax/gen/f16-rmax-avx512skx-u64-acc4.c
1719
src/f16-rminmax/gen/f16-rmin-avx512skx-u64-acc4.c
1820
src/f16-rminmax/gen/f16-rminmax-avx512skx-u64-acc4.c

cmake/gen/avx_microkernels.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ SET(PROD_AVX_MICROKERNEL_SRCS
3434
src/f32-rdminmax/gen/f32-rdmax-2p2x-avx-u32.c
3535
src/f32-rdminmax/gen/f32-rdmin-2p2x-avx-u32.c
3636
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx-u32.c
37+
src/f32-rdsum2/gen/f32-rdsum2-7p7x-minmax-avx.c
3738
src/f32-rminmax/gen/f32-rmax-avx-u32-acc4.c
3839
src/f32-rminmax/gen/f32-rmin-avx-u32-acc4.c
3940
src/f32-rminmax/gen/f32-rminmax-avx-u32-acc4.c
4041
src/f32-rsum/gen/f32-rsum-avx-u32-acc4.c
42+
src/f32-rsum2/gen/f32-rsum2-avx-u8.c
4143
src/f32-vapproxgelu/gen/f32-vapproxgelu-avx-rational-12-10-div.c
4244
src/f32-vbinary/gen/f32-vadd-avx-u16.c
4345
src/f32-vbinary/gen/f32-vaddc-avx-u16.c

0 commit comments

Comments
 (0)