Skip to content

Commit aedee69

Browse files
Improve SIMD usage of Oklab conversions (image-rs#75)
1 parent afa75f3 commit aedee69

File tree

1 file changed

+120
-88
lines changed

1 file changed

+120
-88
lines changed

src/color/oklab.rs

Lines changed: 120 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,83 @@
11
use glam::Vec3A;
22

/// Numeric primitives that the Oklab conversion routines are generic over.
///
/// Every operation takes a `Vec3A` and returns a `Vec3A`, so one call covers
/// all three channels. The `-` gutter entries below are the earlier
/// scalar (`f32`) signatures that this change replaces.
33
trait Operations {
4-
fn srgb_to_linear(c: f32) -> f32;
5-
fn linear_to_srgb(c: f32) -> f32;
6-
fn cbrt(x: f32) -> f32;
/// Converts sRGB-encoded channel values to linear values.
4+
fn srgb_to_linear(c: Vec3A) -> Vec3A;
/// Converts linear channel values back to sRGB-encoded values.
5+
fn linear_to_srgb(c: Vec3A) -> Vec3A;
/// Cube root of each component of `x`.
6+
fn cbrt(x: Vec3A) -> Vec3A;
77
}
88

/// `Operations` implementation built on the exact `f32` math functions
/// (`powf`, `cbrt`), applied to the `x`, `y` and `z` components one by one.
99
struct Reference;
1010
impl Operations for Reference {
11-
fn srgb_to_linear(c: f32) -> f32 {
12-
if c >= 0.04045 {
13-
((c + 0.055) / 1.055).powf(2.4)
14-
} else {
15-
c / 12.92
/// Piecewise sRGB -> linear curve: a power segment for inputs `>= 0.04045`
/// and a linear segment (`c / 12.92`) below that threshold.
11+
fn srgb_to_linear(c: Vec3A) -> Vec3A {
// Scalar helper holding the per-channel formula; it shadows the outer
// method name only inside this function body.
12+
fn srgb_to_linear(c: f32) -> f32 {
13+
if c >= 0.04045 {
14+
((c + 0.055) / 1.055).powf(2.4)
15+
} else {
16+
c / 12.92
17+
}
1618
}
19+
// Apply the scalar helper to each component and repack the result.
20+
Vec3A::new(
21+
srgb_to_linear(c.x),
22+
srgb_to_linear(c.y),
23+
srgb_to_linear(c.z),
24+
)
1725
}
18-
fn linear_to_srgb(c: f32) -> f32 {
19-
if c > 0.0031308 {
20-
1.055 * c.powf(1.0 / 2.4) - 0.055
21-
} else {
22-
12.92 * c
/// Piecewise linear -> sRGB curve: a power segment for inputs `> 0.0031308`
/// and a linear segment (`12.92 * c`) at or below that threshold.
26+
fn linear_to_srgb(c: Vec3A) -> Vec3A {
// Scalar helper holding the per-channel formula.
27+
fn linear_to_srgb(c: f32) -> f32 {
28+
if c > 0.0031308 {
29+
1.055 * c.powf(1.0 / 2.4) - 0.055
30+
} else {
31+
12.92 * c
32+
}
2333
}
34+
// Apply the scalar helper to each component and repack the result.
35+
Vec3A::new(
36+
linear_to_srgb(c.x),
37+
linear_to_srgb(c.y),
38+
linear_to_srgb(c.z),
39+
)
2440
}
25-
fn cbrt(x: f32) -> f32 {
26-
f32::cbrt(x)
/// Component-wise cube root using `f32::cbrt`.
41+
fn cbrt(x: Vec3A) -> Vec3A {
42+
Vec3A::new(x.x.cbrt(), x.y.cbrt(), x.z.cbrt())
2743
}
2844
}
2945

3046
struct Fast;
3147
impl Operations for Fast {
32-
fn srgb_to_linear(c: f32) -> f32 {
33-
if c >= 0.04045 {
34-
// This uses a Padé approximant for ((c + 0.055) / 1.055) ^ 2.4:
35-
// (0.000857709 +0.0359438 x+0.524293 x^2+1.31193 x^3)/(1+0.992498 x-0.119725 x^2)
36-
let c2 = c * c;
37-
let c3 = c2 * c;
38-
f32::min(
39-
1.0,
40-
(0.000857709 + 0.0359438 * c + 0.524293 * c2 + 1.31193 * c3)
41-
/ (1.0 + 0.992498 * c - 0.119725 * c2),
42-
)
43-
} else {
44-
c * (1.0 / 12.92)
45-
}
/// Approximate sRGB -> linear conversion for all three channels at once.
///
/// The scalar `if`/`else` of the removed version becomes a lane-wise
/// `Vec3A::select`: both candidate results are computed for every lane and
/// the comparison mask decides which one each lane keeps.
48+
fn srgb_to_linear(c: Vec3A) -> Vec3A {
49+
Vec3A::select(
// Mask: lanes whose input is at or above the 0.04045 threshold.
50+
c.cmpge(Vec3A::splat(0.04045)),
// Value for lanes where the mask is set (rational approximation).
51+
{
52+
// This uses a Padé approximant for ((c + 0.055) / 1.055) ^ 2.4:
53+
// (0.000857709 +0.0359438 x+0.524293 x^2+1.31193 x^3)/(1+0.992498 x-0.119725 x^2)
54+
let c2 = c * c;
55+
let c3 = c2 * c;
// Cap the approximation at 1.0 per lane, as the scalar version did with `f32::min`.
56+
Vec3A::min(
57+
Vec3A::ONE,
58+
(0.000857709 + 0.0359438 * c + 0.524293 * c2 + 1.31193 * c3)
59+
/ (Vec3A::ONE + 0.992498 * c - 0.119725 * c2),
60+
)
61+
},
// Value for lanes below the threshold: the linear segment.
62+
c * (1.0 / 12.92),
63+
)
4664
}
47-
fn linear_to_srgb(c: f32) -> f32 {
48-
if c > 0.0031308 {
49-
// This uses a Padé approximant for 1.055 c^(1/2.4) - 0.055:
50-
// (-0.0117264+21.0897 x+949.46 x^2+2225.62 x^3)/(1+176.398 x+1983.15 x^2+1035.65 x^3)
51-
let c2 = c * c;
52-
let c3 = c2 * c;
53-
(-0.0117264 + 21.0897 * c + 949.46 * c2 + 2225.62 * c3)
54-
/ (1.0 + 176.398 * c + 1983.15 * c2 + 1035.65 * c3)
55-
} else {
56-
12.92 * c
57-
}
/// Approximate linear -> sRGB conversion for all three channels at once.
///
/// Both the rational approximation and the linear segment are evaluated for
/// every lane; `Vec3A::select` then keeps one of them per lane based on the
/// comparison mask.
65+
fn linear_to_srgb(c: Vec3A) -> Vec3A {
66+
Vec3A::select(
// Mask: lanes whose input is strictly above the 0.0031308 threshold.
67+
c.cmpgt(Vec3A::splat(0.0031308)),
// Value for lanes where the mask is set (rational approximation).
68+
{
69+
// This uses a Padé approximant for 1.055 c^(1/2.4) - 0.055:
70+
// (-0.0117264+21.0897 x+949.46 x^2+2225.62 x^3)/(1+176.398 x+1983.15 x^2+1035.65 x^3)
71+
let c2 = c * c;
72+
let c3 = c2 * c;
73+
(-0.0117264 + 21.0897 * c + 949.46 * c2 + 2225.62 * c3)
74+
/ (1.0 + 176.398 * c + 1983.15 * c2 + 1035.65 * c3)
75+
},
// Value for lanes at or below the threshold: the linear segment.
76+
c * 12.92,
77+
)
5878
}
5979
#[allow(clippy::excessive_precision)]
60-
fn cbrt(x: f32) -> f32 {
80+
fn cbrt(x: Vec3A) -> Vec3A {
6181
// This is the fast cbrt approximation from the oklab crate.
6282
// Source: https://gitlab.com/kornelski/oklab/-/blob/d3c074f154187dd5c0642119a6402a6c0753d70c/oklab/src/lib.rs#L61
6383
// Author: Kornel (https://gitlab.com/kornelski/)
@@ -68,55 +88,52 @@ impl Operations for Fast {
6888
const F: f32 = 1.6071428061e+0;
6989
const G: f32 = 3.5714286566e-1;
7090

71-
let mut t = f32::from_bits((x.to_bits() / 3).wrapping_add(B));
91+
let mut t = Vec3A::from_array(
92+
x.to_array()
93+
.map(|x| f32::from_bits((x.to_bits() / 3).wrapping_add(B))),
94+
);
7295
let s = C + (t * t) * (t / x);
7396
t *= G + F / (s + E + D / s);
7497
t
7598
}
7699
}
77100

/// Converts an sRGB color to Oklab, with the numeric primitives supplied by `O`.
///
/// Steps in the added (`+`) version: linearize the input, form three weighted
/// sums of the linear channels (`lms`), take the cube root of each, form three
/// weighted sums of those roots (`lab`), then add 0.5 to the second and third
/// components. The removed (`-`) version spells the same weighted sums out
/// with scalar arithmetic.
78101
#[allow(clippy::excessive_precision)]
79-
fn srgb_to_oklab_impl<O: Operations>(rgb: Vec3A) -> Vec3A {
80-
let [r, g, b] = rgb.to_array().map(O::srgb_to_linear);
81-
82-
let mut l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b;
83-
let mut m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b;
84-
let mut s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b;
85-
86-
l = O::cbrt(l);
87-
m = O::cbrt(m);
88-
s = O::cbrt(s);
89-
90-
let l_final = l * 0.2104542553 + m * 0.7936177850 + s * -0.0040720468;
91-
let a = l * 1.9779984951 + m * -2.4285922050 + s * 0.4505937099;
92-
let b = l * 0.0259040371 + m * 0.7827717662 + s * -0.8086757660;
102+
fn srgb_to_oklab_impl<O: Operations>(srgb: Vec3A) -> Vec3A {
// One vector call linearizes all three channels.
103+
let rgb = O::srgb_to_linear(srgb);
104+
// Each dot product is one row of the 3x3 linear-RGB -> LMS matrix.
105+
let lms = Vec3A::new(
106+
rgb.dot(Vec3A::new(0.4122214708, 0.5363325363, 0.0514459929)),
107+
rgb.dot(Vec3A::new(0.2119034982, 0.6806995451, 0.1073969566)),
108+
rgb.dot(Vec3A::new(0.0883024619, 0.2817188376, 0.6299787005)),
109+
);
110+
let lms = O::cbrt(lms);
111+
// Each dot product is one row of the 3x3 matrix applied after the cube root.
112+
let lab = Vec3A::new(
113+
lms.dot(Vec3A::new(0.2104542553, 0.7936177850, -0.0040720468)),
114+
lms.dot(Vec3A::new(1.9779984951, -2.4285922050, 0.4505937099)),
115+
lms.dot(Vec3A::new(0.0259040371, 0.7827717662, -0.8086757660)),
116+
);
93117

94118
// normalize everything to the 0..1 range
95-
Vec3A::new(l_final, a + 0.5, b + 0.5)
119+
lab + Vec3A::new(0.0, 0.5, 0.5)
96120
}
/// Converts an Oklab color back to sRGB, with `linear_to_srgb` supplied by `O`.
///
/// Steps in the added (`+`) version: subtract 0.5 from the second and third
/// components, form three weighted sums (`lms`), cube each component, form
/// three weighted sums of the cubes (`rgb`), then sRGB-encode all channels in
/// one call. The 0.5 offset mirrors the `+ Vec3A::new(0.0, 0.5, 0.5)` at the
/// end of `srgb_to_oklab_impl`.
97121
#[allow(clippy::excessive_precision)]
98122
fn oklab_to_srgb_impl<O: Operations>(lab: Vec3A) -> Vec3A {
99-
let l_org = lab.x;
100-
let a = lab.y - 0.5;
101-
let b = lab.z - 0.5;
102-
103-
let mut l = l_org + a * 0.3963377774 + b * 0.2158037573;
104-
let mut m = l_org + a * -0.1055613458 + b * -0.0638541728;
105-
let mut s = l_org + a * -0.0894841775 + b * -1.2914855480;
106-
107-
l = l * l * l;
108-
m = m * m * m;
109-
s = s * s * s;
110-
111-
let r = l * 4.0767416621 + m * -3.3077115913 + s * 0.2309699292;
112-
let g = l * -1.2684380046 + m * 2.6097574011 + s * -0.3413193965;
113-
let b = l * -0.0041960863 + m * -0.7034186147 + s * 1.7076147010;
114-
115-
Vec3A::new(
116-
O::linear_to_srgb(r),
117-
O::linear_to_srgb(g),
118-
O::linear_to_srgb(b),
119-
)
// Remove the 0.5 offset from the second and third components; the first is left as is.
123+
let lab_norm = lab - Vec3A::new(0.0, 0.5, 0.5);
// The leading 1.0 in each row carries the first component through unscaled,
// matching the bare `l_org` term of the removed scalar code.
124+
let lms = Vec3A::new(
125+
lab_norm.dot(Vec3A::new(1.0, 0.3963377774, 0.2158037573)),
126+
lab_norm.dot(Vec3A::new(1.0, -0.1055613458, -0.0638541728)),
127+
lab_norm.dot(Vec3A::new(1.0, -0.0894841775, -1.2914855480)),
128+
);
129+
let lms = lms * lms * lms; // lms^3
// Each dot product is one row of the 3x3 matrix producing linear RGB.
130+
let rgb = Vec3A::new(
131+
lms.dot(Vec3A::new(4.0767416621, -3.3077115913, 0.2309699292)),
132+
lms.dot(Vec3A::new(-1.2684380046, 2.6097574011, -0.3413193965)),
133+
lms.dot(Vec3A::new(-0.0041960863, -0.7034186147, 1.7076147010)),
134+
);
135+
136+
O::linear_to_srgb(rgb)
120137
}
121138

122139
#[allow(unused)]
@@ -200,43 +217,58 @@ mod tests {
200217
}
201218
}
202219

/// Test helper that exposes the vector-based `Operations` of `O` as plain
/// `f32 -> f32` functions, so the tests can keep passing scalar functions.
///
/// Each method broadcasts its argument to every lane with `Vec3A::splat`,
/// runs the vector operation, and returns the `x` lane of the result.
220+
pub struct Scalar<O>(O);
221+
impl<O: Operations> Scalar<O> {
/// Scalar view of `O::srgb_to_linear`.
222+
fn srgb_to_linear(c: f32) -> f32 {
223+
O::srgb_to_linear(Vec3A::splat(c)).x
224+
}
/// Scalar view of `O::linear_to_srgb`.
225+
fn linear_to_srgb(c: f32) -> f32 {
226+
O::linear_to_srgb(Vec3A::splat(c)).x
227+
}
/// Scalar view of `O::cbrt`.
228+
fn cbrt(x: f32) -> f32 {
229+
O::cbrt(Vec3A::splat(x)).x
230+
}
231+
}
// Shorthand for the scalar views of the two `Operations` implementations.
232+
type RefScalar = Scalar<Reference>;
233+
type FastScalar = Scalar<Fast>;
234+
/// Round-trips every value `c / 255` for `c` in `0..=255` through
/// `srgb_to_linear` followed by `linear_to_srgb`, for both the reference and
/// the fast implementation, then checks a few endpoint values.
203235
#[test]
204236
fn test_linear_srgb() {
// Reference implementation: the round trip must land within 1e-6 of the input.
205237
for c in 0..=255 {
206238
let c = c as f32 / 255.0;
207-
let l = Reference::srgb_to_linear(c);
208-
let c2 = Reference::linear_to_srgb(l);
239+
let l = RefScalar::srgb_to_linear(c);
240+
let c2 = RefScalar::linear_to_srgb(l);
209241

210242
assert!((c - c2).abs() < 1e-6, "{c} -> {c2}");
211243
}
212244

// Fast implementation: looser round-trip tolerance (2.5e-3), and both the
// intermediate and the final value must stay inside 0.0..=1.0.
213245
for c in 0..=255 {
214246
let c = c as f32 / 255.0;
215-
let l = Fast::srgb_to_linear(c);
216-
let c2 = Fast::linear_to_srgb(l);
247+
let l = FastScalar::srgb_to_linear(c);
248+
let c2 = FastScalar::linear_to_srgb(l);
217249

218250
assert!((c - c2).abs() < 2.5e-3, "{c} -> {c2}");
219251
assert!((0.0..=1.0).contains(&l), "{c} -> {l}");
// NOTE(review): this assertion checks `c2`, but its failure message prints `l`;
// the message probably should interpolate `c2` instead.
220252
assert!((0.0..=1.0).contains(&c2), "{c} -> {l}");
221253
}
222254

// Endpoint checks: exact zero for an input of 0.0, and within 1e-6 of 1.0 for an input of 1.0.
223-
assert_eq!(Reference::srgb_to_linear(0.0), 0.0);
224-
assert!((Reference::srgb_to_linear(1.0) - 1.0).abs() < 1e-6);
225-
assert_eq!(Fast::linear_to_srgb(0.0), 0.0);
226-
assert!((Fast::srgb_to_linear(1.0) - 1.0).abs() < 1e-6);
255+
assert_eq!(RefScalar::srgb_to_linear(0.0), 0.0);
256+
assert!((RefScalar::srgb_to_linear(1.0) - 1.0).abs() < 1e-6);
257+
assert_eq!(FastScalar::linear_to_srgb(0.0), 0.0);
258+
assert!((FastScalar::srgb_to_linear(1.0) - 1.0).abs() < 1e-6);
227259
}
228260

/// Pins the error report for the fast `srgb_to_linear` against the reference
/// implementation to an exact string.
// NOTE(review): `get_error_stats` is defined outside this view; it presumably
// compares the two functions over a range of inputs and formats the average
// and maximum error — confirm against its definition.
229261
#[test]
230262
fn test_error_fast_srgb_to_linear() {
231263
assert_eq!(
232-
get_error_stats(Reference::srgb_to_linear, Fast::srgb_to_linear),
264+
get_error_stats(RefScalar::srgb_to_linear, FastScalar::srgb_to_linear),
233265
"Error: avg=0.00002514 max=0.00013047 for 0.999"
234266
);
235267
}
/// Pins the error report for the fast `linear_to_srgb` against the reference
/// implementation to an exact string.
// NOTE(review): `get_error_stats` is defined outside this view; it presumably
// compares the two functions over a range of inputs and formats the average
// and maximum error — confirm against its definition.
236268
#[test]
237269
fn test_error_fast_linear_to_srgb() {
238270
assert_eq!(
239-
get_error_stats(Reference::linear_to_srgb, Fast::linear_to_srgb),
271+
get_error_stats(RefScalar::linear_to_srgb, FastScalar::linear_to_srgb),
240272
"Error: avg=0.00105457 max=0.00236702 for 0.732"
241273
);
242274
}

0 commit comments

Comments
 (0)