Skip to content

Commit 599a63e

Browse files
ritchie46nevi-me
authored andcommitted
ARROW-11428: [Rust] Add power_scalar kernel
Adds a SISD and SIMD kernel to raise a `Float32/64` array to a power of a `scalar` of the same type. We could also make a thin `sqrt` wrapper. I also added a `unary_op` fn to `ArrowNumeric` type as this seemed the most generic way to implement this. Next PR I could add support for a binary version of this (e.g. array to the power of array). _edit_: The `ArrowFloatNumericType` trait was added because the [Simd::powf](https://rust-lang.github.io/packed_simd/packed_simd_2/struct.Simd.html#method.powf-6) is only available for float arrays (e.g. `[f32, N]`, `[f64, N]`). However, the *packed_simd* crate doesn't expose this functionality via a trait, but directly on the type, hence the extra trait. Closes #9361 from ritchie46/power_kernel Authored-by: Ritchie Vink <[email protected]> Signed-off-by: Neville Dipale <[email protected]>
1 parent 1c219e3 commit 599a63e

File tree

2 files changed

+112
-23
lines changed

2 files changed

+112
-23
lines changed

rust/arrow/src/compute/kernels/arithmetic.rs

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,61 +30,80 @@ use num::{One, Zero};
3030
use crate::buffer::Buffer;
3131
#[cfg(simd)]
3232
use crate::buffer::MutableBuffer;
33-
use crate::compute::util::combine_option_bitmap;
33+
use crate::compute::{kernels::arity::unary, util::combine_option_bitmap};
3434
use crate::datatypes;
3535
use crate::datatypes::ArrowNumericType;
3636
use crate::error::{ArrowError, Result};
3737
use crate::{array::*, util::bit_util};
38+
use num::traits::Pow;
3839
#[cfg(simd)]
3940
use std::borrow::BorrowMut;
4041
#[cfg(simd)]
4142
use std::slice::{ChunksExact, ChunksExactMut};
4243

43-
/// Helper function to perform math lambda function on values from single array of signed numeric
44-
/// type. If value is null then the output value is also null, so `-null` is `null`.
45-
pub fn signed_unary_math_op<T, F>(
44+
/// SIMD vectorized version of `unary_math_op` above specialized for signed numerical values.
45+
#[cfg(simd)]
46+
fn simd_signed_unary_math_op<T, SIMD_OP, SCALAR_OP>(
4647
array: &PrimitiveArray<T>,
47-
op: F,
48+
simd_op: SIMD_OP,
49+
scalar_op: SCALAR_OP,
4850
) -> Result<PrimitiveArray<T>>
4951
where
5052
T: datatypes::ArrowSignedNumericType,
51-
T::Native: Neg<Output = T::Native>,
52-
F: Fn(T::Native) -> T::Native,
53+
SIMD_OP: Fn(T::SignedSimd) -> T::SignedSimd,
54+
SCALAR_OP: Fn(T::Native) -> T::Native,
5355
{
54-
let values = array.values().iter().map(|v| op(*v));
55-
// JUSTIFICATION
56-
// Benefit
57-
// ~60% speedup
58-
// Soundness
59-
// `values` is an iterator with a known size.
60-
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
56+
let lanes = T::lanes();
57+
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
58+
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
59+
60+
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
61+
let mut array_chunks = array.values().chunks_exact(lanes);
62+
63+
result_chunks
64+
.borrow_mut()
65+
.zip(array_chunks.borrow_mut())
66+
.for_each(|(result_slice, input_slice)| {
67+
let simd_input = T::load_signed(input_slice);
68+
let simd_result = T::signed_unary_op(simd_input, &simd_op);
69+
T::write_signed(simd_result, result_slice);
70+
});
71+
72+
let result_remainder = result_chunks.into_remainder();
73+
let array_remainder = array_chunks.remainder();
74+
75+
result_remainder.into_iter().zip(array_remainder).for_each(
76+
|(scalar_result, scalar_input)| {
77+
*scalar_result = scalar_op(*scalar_input);
78+
},
79+
);
6180

6281
let data = ArrayData::new(
6382
T::DATA_TYPE,
6483
array.len(),
6584
None,
6685
array.data_ref().null_buffer().cloned(),
6786
0,
68-
vec![buffer],
87+
vec![result.into()],
6988
vec![],
7089
);
7190
Ok(PrimitiveArray::<T>::from(Arc::new(data)))
7291
}
7392

74-
/// SIMD vectorized version of `signed_unary_math_op` above.
7593
#[cfg(simd)]
76-
fn simd_signed_unary_math_op<T, SIMD_OP, SCALAR_OP>(
94+
fn simd_float_unary_math_op<T, SIMD_OP, SCALAR_OP>(
7795
array: &PrimitiveArray<T>,
7896
simd_op: SIMD_OP,
7997
scalar_op: SCALAR_OP,
8098
) -> Result<PrimitiveArray<T>>
8199
where
82-
T: datatypes::ArrowSignedNumericType,
83-
SIMD_OP: Fn(T::SignedSimd) -> T::SignedSimd,
100+
T: datatypes::ArrowFloatNumericType,
101+
SIMD_OP: Fn(T::Simd) -> T::Simd,
84102
SCALAR_OP: Fn(T::Native) -> T::Native,
85103
{
86104
let lanes = T::lanes();
87105
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
106+
88107
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
89108

90109
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
@@ -94,9 +113,9 @@ where
94113
.borrow_mut()
95114
.zip(array_chunks.borrow_mut())
96115
.for_each(|(result_slice, input_slice)| {
97-
let simd_input = T::load_signed(input_slice);
98-
let simd_result = T::signed_unary_op(simd_input, &simd_op);
99-
T::write_signed(simd_result, result_slice);
116+
let simd_input = T::load(input_slice);
117+
let simd_result = T::unary_op(simd_input, &simd_op);
118+
T::write(simd_result, result_slice);
100119
});
101120

102121
let result_remainder = result_chunks.into_remainder();
@@ -536,7 +555,29 @@ where
536555
#[cfg(simd)]
537556
return simd_signed_unary_math_op(array, |x| -x, |x| -x);
538557
#[cfg(not(simd))]
539-
return signed_unary_math_op(array, |x| -x);
558+
return Ok(unary(array, |x| -x));
559+
}
560+
561+
/// Raise array with floating point values to the power of a scalar.
562+
pub fn powf_scalar<T>(
563+
array: &PrimitiveArray<T>,
564+
raise: T::Native,
565+
) -> Result<PrimitiveArray<T>>
566+
where
567+
T: datatypes::ArrowFloatNumericType,
568+
T::Native: Pow<T::Native, Output = T::Native>,
569+
{
570+
#[cfg(simd)]
571+
{
572+
let raise_vector = T::init(raise);
573+
return simd_float_unary_math_op(
574+
array,
575+
|x| T::pow(x, raise_vector),
576+
|x| x.pow(raise),
577+
);
578+
}
579+
#[cfg(not(simd))]
580+
return Ok(unary(array, |x| x.pow(raise)));
540581
}
541582

542583
/// Perform `left * right` operation on two arrays. If either left or right value is null
@@ -808,4 +849,16 @@ mod tests {
808849
.collect();
809850
assert_eq!(expected, actual);
810851
}
852+
853+
#[test]
854+
fn test_primitive_array_raise_power_scalar() {
855+
let a = Float64Array::from(vec![1.0, 2.0, 3.0]);
856+
let actual = powf_scalar(&a, 2.0).unwrap();
857+
let expected = Float64Array::from(vec![1.0, 4.0, 9.0]);
858+
assert_eq!(expected, actual);
859+
let a = Float64Array::from(vec![Some(1.0), None, Some(3.0)]);
860+
let actual = powf_scalar(&a, 2.0).unwrap();
861+
let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]);
862+
assert_eq!(expected, actual);
863+
}
811864
}

rust/arrow/src/datatypes.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,8 @@ where
605605

606606
/// Writes a SIMD result back to a slice
607607
fn write(simd_result: Self::Simd, slice: &mut [Self::Native]);
608+
609+
fn unary_op<F: Fn(Self::Simd) -> Self::Simd>(a: Self::Simd, op: F) -> Self::Simd;
608610
}
609611

610612
#[cfg(not(simd))]
@@ -806,6 +808,14 @@ macro_rules! make_numeric_type {
806808
fn write(simd_result: Self::Simd, slice: &mut [Self::Native]) {
807809
unsafe { simd_result.write_to_slice_unaligned_unchecked(slice) };
808810
}
811+
812+
#[inline]
813+
fn unary_op<F: Fn(Self::Simd) -> Self::Simd>(
814+
a: Self::Simd,
815+
op: F,
816+
) -> Self::Simd {
817+
op(a)
818+
}
809819
}
810820

811821
#[cfg(not(simd))]
@@ -909,6 +919,32 @@ make_signed_numeric_type!(Int64Type, i64x8);
909919
make_signed_numeric_type!(Float32Type, f32x16);
910920
make_signed_numeric_type!(Float64Type, f64x8);
911921

922+
#[cfg(simd)]
923+
pub trait ArrowFloatNumericType: ArrowNumericType {
924+
fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd;
925+
}
926+
927+
#[cfg(not(simd))]
928+
pub trait ArrowFloatNumericType: ArrowNumericType {}
929+
930+
macro_rules! make_float_numeric_type {
931+
($impl_ty:ty, $simd_ty:ident) => {
932+
#[cfg(simd)]
933+
impl ArrowFloatNumericType for $impl_ty {
934+
#[inline]
935+
fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd {
936+
base.powf(raise)
937+
}
938+
}
939+
940+
#[cfg(not(simd))]
941+
impl ArrowFloatNumericType for $impl_ty {}
942+
};
943+
}
944+
945+
make_float_numeric_type!(Float32Type, f32x16);
946+
make_float_numeric_type!(Float64Type, f64x8);
947+
912948
/// A subtype of primitive type that represents temporal values.
913949
pub trait ArrowTemporalType: ArrowPrimitiveType {}
914950

0 commit comments

Comments
 (0)