Skip to content

Commit 73f0b6e

Browse files
committed
kmeans: Split k-independent inner loop into its own function
This reduces the code size of the k-means implementation by 28%. Hot instructions are concentrated in a shared block which is only 5.1% of the total k-means code bytes.

Before:

| Size | Name |
| --- | --- |
| 1,277B | rav1e::util::kmeans::kmeans |
| 1,229B | rav1e::util::kmeans::kmeans |
| 1,163B | rav1e::util::kmeans::kmeans |
| 1,085B | rav1e::util::kmeans::kmeans |
| 1,080B | rav1e::util::kmeans::kmeans |
| 1,004B | rav1e::util::kmeans::kmeans |
| 6,838B | filtered data size |

After:

| Size | Name |
| --- | --- |
| 1,001B | rav1e::util::kmeans::kmeans |
| 891B | rav1e::util::kmeans::kmeans |
| 768B | rav1e::util::kmeans::kmeans |
| 766B | rav1e::util::kmeans::kmeans |
| 669B | rav1e::util::kmeans::kmeans |
| 573B | rav1e::util::kmeans::kmeans |
| 251B | rav1e::util::kmeans::scan |
| 4,919B | filtered data size |
1 parent aa1cc7c commit 73f0b6e

File tree

1 file changed

+38
-27
lines changed

1 file changed

+38
-27
lines changed

src/util/kmeans.rs

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ where
2727
high[K - 1] = data.len();
2828
sum[K - 1] = means[K - 1].into();
2929

30-
let data_to = |n: usize| unsafe { data.get_unchecked(..n) }.iter();
31-
let data_from = |n: usize| unsafe { data.get_unchecked(n..) }.iter();
32-
3330
// Constrain complexity to O(n log n)
3431
let limit = 2 * (usize::BITS - data.len().leading_zeros());
3532
for _ in 0..limit {
@@ -40,31 +37,9 @@ where
4037
.zip(low.iter_mut().skip(1).zip(&mut high))
4138
.enumerate()
4239
{
43-
let mut n = *high;
44-
let mut s = sum[i];
45-
for &d in data_to(n).rev().take_while(|&d| *d > threshold) {
46-
s -= d.into();
47-
n -= 1;
48-
}
49-
for &d in data_from(n).take_while(|&d| *d <= threshold) {
50-
s += d.into();
51-
n += 1;
52-
}
53-
*high = n;
54-
sum[i] = s;
55-
56-
let mut n = *low;
57-
let mut s = sum[i + 1];
58-
for &d in data_from(n).take_while(|&d| *d < threshold) {
59-
s -= d.into();
60-
n += 1;
61-
}
62-
for &d in data_to(n).rev().take_while(|&d| *d >= threshold) {
63-
s += d.into();
64-
n -= 1;
40+
unsafe {
41+
scan(high, low, sum.get_unchecked_mut(i..=i + 1), data, threshold);
6542
}
66-
*low = n;
67-
sum[i + 1] = s;
6843
}
6944
let mut changed = false;
7045
for (((m, sum), high), low) in
@@ -90,6 +65,42 @@ where
9065
means
9166
}
9267

68+
/// Re-positions the two indices bracketing a threshold `t` and keeps the
/// two running sums in step with every element the indices pass over.
///
/// `*high` is first walked backwards while the element just below it is
/// greater than `t` (each such element is subtracted from `sum[0]`), then
/// forwards while the element at it is less than or equal to `t` (each
/// such element is added to `sum[0]`).
///
/// `*low` is first walked forwards while the element at it is less than
/// `t` (each such element is subtracted from `sum[1]`), then backwards
/// while the element just below it is greater than or equal to `t` (each
/// such element is added to `sum[1]`).
///
/// NOTE(review): `data` is presumably sorted in ascending order, in which
/// case `*high` ends at the first element `> t` and `*low` at the first
/// element `>= t`; nothing in this function checks that.
///
/// The function is generic only over the element type, so one compiled
/// copy can serve every caller that uses the same `T`; `#[inline(never)]`
/// keeps it from being duplicated into those callers.
///
/// # Safety
///
/// The caller must guarantee all of the following:
///
/// - `sum` contains at least two elements,
/// - `*high <= data.len()`,
/// - `*low <= data.len()`.
///
/// Violating any of them makes the unchecked accesses below read or write
/// out of bounds, which is undefined behaviour.
#[inline(never)]
unsafe fn scan<T>(
  high: &mut usize, low: &mut usize, sum: &mut [i64], data: &[T], t: T,
) where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
{
  // Catch contract violations in debug builds; compiled out in release.
  debug_assert!(sum.len() >= 2);
  debug_assert!(*high <= data.len());
  debug_assert!(*low <= data.len());

  let mut n = *high;
  // SAFETY: the caller guarantees `sum.len() >= 2`.
  let mut s = unsafe { *sum.get_unchecked(0) };
  // SAFETY: the caller guarantees `*high <= data.len()`.
  let below = unsafe { data.get_unchecked(..n) };
  for &d in below.iter().rev().take_while(|&d| *d > t) {
    s -= d.into();
    n -= 1;
  }
  // SAFETY: `n` started at most at `data.len()` and was only decremented,
  // once per element of `data[..n]`, so `0 <= n <= data.len()` still holds.
  let above = unsafe { data.get_unchecked(n..) };
  for &d in above.iter().take_while(|&d| *d <= t) {
    s += d.into();
    n += 1;
  }
  *high = n;
  // SAFETY: the caller guarantees `sum.len() >= 2`.
  unsafe { *sum.get_unchecked_mut(0) = s };

  let mut n = *low;
  // SAFETY: the caller guarantees `sum.len() >= 2`.
  let mut s = unsafe { *sum.get_unchecked(1) };
  // SAFETY: the caller guarantees `*low <= data.len()`.
  let above = unsafe { data.get_unchecked(n..) };
  for &d in above.iter().take_while(|&d| *d < t) {
    s -= d.into();
    n += 1;
  }
  // SAFETY: `n` was incremented at most once per element of `data[n..]`,
  // so it cannot have moved past `data.len()`.
  let below = unsafe { data.get_unchecked(..n) };
  for &d in below.iter().rev().take_while(|&d| *d >= t) {
    s += d.into();
    n -= 1;
  }
  *low = n;
  // SAFETY: the caller guarantees `sum.len() >= 2`.
  unsafe { *sum.get_unchecked_mut(1) = s };
}
103+
93104
#[cfg(test)]
94105
mod test {
95106
use super::*;

0 commit comments

Comments
 (0)