@@ -30,61 +30,80 @@ use num::{One, Zero};
3030use crate :: buffer:: Buffer ;
3131#[ cfg( simd) ]
3232use crate :: buffer:: MutableBuffer ;
33- use crate :: compute:: util:: combine_option_bitmap;
33+ use crate :: compute:: { kernels :: arity :: unary , util:: combine_option_bitmap} ;
3434use crate :: datatypes;
3535use crate :: datatypes:: ArrowNumericType ;
3636use crate :: error:: { ArrowError , Result } ;
3737use crate :: { array:: * , util:: bit_util} ;
38+ use num:: traits:: Pow ;
3839#[ cfg( simd) ]
3940use std:: borrow:: BorrowMut ;
4041#[ cfg( simd) ]
4142use std:: slice:: { ChunksExact , ChunksExactMut } ;
4243
43- /// Helper function to perform math lambda function on values from single array of signed numeric
44- /// type. If value is null then the output value is also null, so `-null` is `null`.
45- pub fn signed_unary_math_op < T , F > (
44+ /// SIMD vectorized version of `unary_math_op` above specialized for signed numerical values.
45+ # [ cfg ( simd ) ]
46+ fn simd_signed_unary_math_op < T , SIMD_OP , SCALAR_OP > (
4647 array : & PrimitiveArray < T > ,
47- op : F ,
48+ simd_op : SIMD_OP ,
49+ scalar_op : SCALAR_OP ,
4850) -> Result < PrimitiveArray < T > >
4951where
5052 T : datatypes:: ArrowSignedNumericType ,
51- T :: Native : Neg < Output = T :: Native > ,
52- F : Fn ( T :: Native ) -> T :: Native ,
53+ SIMD_OP : Fn ( T :: SignedSimd ) -> T :: SignedSimd ,
54+ SCALAR_OP : Fn ( T :: Native ) -> T :: Native ,
5355{
54- let values = array. values ( ) . iter ( ) . map ( |v| op ( * v) ) ;
55- // JUSTIFICATION
56- // Benefit
57- // ~60% speedup
58- // Soundness
59- // `values` is an iterator with a known size.
60- let buffer = unsafe { Buffer :: from_trusted_len_iter ( values) } ;
56+ let lanes = T :: lanes ( ) ;
57+ let buffer_size = array. len ( ) * std:: mem:: size_of :: < T :: Native > ( ) ;
58+ let mut result = MutableBuffer :: new ( buffer_size) . with_bitset ( buffer_size, false ) ;
59+
60+ let mut result_chunks = result. typed_data_mut ( ) . chunks_exact_mut ( lanes) ;
61+ let mut array_chunks = array. values ( ) . chunks_exact ( lanes) ;
62+
63+ result_chunks
64+ . borrow_mut ( )
65+ . zip ( array_chunks. borrow_mut ( ) )
66+ . for_each ( |( result_slice, input_slice) | {
67+ let simd_input = T :: load_signed ( input_slice) ;
68+ let simd_result = T :: signed_unary_op ( simd_input, & simd_op) ;
69+ T :: write_signed ( simd_result, result_slice) ;
70+ } ) ;
71+
72+ let result_remainder = result_chunks. into_remainder ( ) ;
73+ let array_remainder = array_chunks. remainder ( ) ;
74+
75+ result_remainder. into_iter ( ) . zip ( array_remainder) . for_each (
76+ |( scalar_result, scalar_input) | {
77+ * scalar_result = scalar_op ( * scalar_input) ;
78+ } ,
79+ ) ;
6180
6281 let data = ArrayData :: new (
6382 T :: DATA_TYPE ,
6483 array. len ( ) ,
6584 None ,
6685 array. data_ref ( ) . null_buffer ( ) . cloned ( ) ,
6786 0 ,
68- vec ! [ buffer ] ,
87+ vec ! [ result . into ( ) ] ,
6988 vec ! [ ] ,
7089 ) ;
7190 Ok ( PrimitiveArray :: < T > :: from ( Arc :: new ( data) ) )
7291}
7392
74- /// SIMD vectorized version of `signed_unary_math_op` above.
7593#[ cfg( simd) ]
76- fn simd_signed_unary_math_op < T , SIMD_OP , SCALAR_OP > (
94+ fn simd_float_unary_math_op < T , SIMD_OP , SCALAR_OP > (
7795 array : & PrimitiveArray < T > ,
7896 simd_op : SIMD_OP ,
7997 scalar_op : SCALAR_OP ,
8098) -> Result < PrimitiveArray < T > >
8199where
82- T : datatypes:: ArrowSignedNumericType ,
83- SIMD_OP : Fn ( T :: SignedSimd ) -> T :: SignedSimd ,
100+ T : datatypes:: ArrowFloatNumericType ,
101+ SIMD_OP : Fn ( T :: Simd ) -> T :: Simd ,
84102 SCALAR_OP : Fn ( T :: Native ) -> T :: Native ,
85103{
86104 let lanes = T :: lanes ( ) ;
87105 let buffer_size = array. len ( ) * std:: mem:: size_of :: < T :: Native > ( ) ;
106+
88107 let mut result = MutableBuffer :: new ( buffer_size) . with_bitset ( buffer_size, false ) ;
89108
90109 let mut result_chunks = result. typed_data_mut ( ) . chunks_exact_mut ( lanes) ;
94113 . borrow_mut ( )
95114 . zip ( array_chunks. borrow_mut ( ) )
96115 . for_each ( |( result_slice, input_slice) | {
97- let simd_input = T :: load_signed ( input_slice) ;
98- let simd_result = T :: signed_unary_op ( simd_input, & simd_op) ;
99- T :: write_signed ( simd_result, result_slice) ;
116+ let simd_input = T :: load ( input_slice) ;
117+ let simd_result = T :: unary_op ( simd_input, & simd_op) ;
118+ T :: write ( simd_result, result_slice) ;
100119 } ) ;
101120
102121 let result_remainder = result_chunks. into_remainder ( ) ;
@@ -536,7 +555,29 @@ where
536555 #[ cfg( simd) ]
537556 return simd_signed_unary_math_op ( array, |x| -x, |x| -x) ;
538557 #[ cfg( not( simd) ) ]
539- return signed_unary_math_op ( array, |x| -x) ;
558+ return Ok ( unary ( array, |x| -x) ) ;
559+ }
560+
561+ /// Raise array with floating point values to the power of a scalar.
562+ pub fn powf_scalar < T > (
563+ array : & PrimitiveArray < T > ,
564+ raise : T :: Native ,
565+ ) -> Result < PrimitiveArray < T > >
566+ where
567+ T : datatypes:: ArrowFloatNumericType ,
568+ T :: Native : Pow < T :: Native , Output = T :: Native > ,
569+ {
570+ #[ cfg( simd) ]
571+ {
572+ let raise_vector = T :: init ( raise) ;
573+ return simd_float_unary_math_op (
574+ array,
575+ |x| T :: pow ( x, raise_vector) ,
576+ |x| x. pow ( raise) ,
577+ ) ;
578+ }
579+ #[ cfg( not( simd) ) ]
580+ return Ok ( unary ( array, |x| x. pow ( raise) ) ) ;
540581}
541582
542583/// Perform `left * right` operation on two arrays. If either left or right value is null
@@ -808,4 +849,16 @@ mod tests {
808849 . collect ( ) ;
809850 assert_eq ! ( expected, actual) ;
810851 }
852+
853+ #[ test]
854+ fn test_primitive_array_raise_power_scalar ( ) {
855+ let a = Float64Array :: from ( vec ! [ 1.0 , 2.0 , 3.0 ] ) ;
856+ let actual = powf_scalar ( & a, 2.0 ) . unwrap ( ) ;
857+ let expected = Float64Array :: from ( vec ! [ 1.0 , 4.0 , 9.0 ] ) ;
858+ assert_eq ! ( expected, actual) ;
859+ let a = Float64Array :: from ( vec ! [ Some ( 1.0 ) , None , Some ( 3.0 ) ] ) ;
860+ let actual = powf_scalar ( & a, 2.0 ) . unwrap ( ) ;
861+ let expected = Float64Array :: from ( vec ! [ Some ( 1.0 ) , None , Some ( 9.0 ) ] ) ;
862+ assert_eq ! ( expected, actual) ;
863+ }
811864}
0 commit comments