-
Notifications
You must be signed in to change notification settings - Fork 2k
[datafusion-spark] Add Spark-compatible ceil function #20593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c9f25b2
35adebd
fb82ec0
514e64b
f0d428c
9aad479
0ad51f9
6640dc2
bb81de7
de00eed
9e7a0fd
ee99235
ab37090
14b6530
f3069d1
ce9f345
336870f
3dce8df
f2c6020
8d4c63f
2d3658f
fd23001
1b04c3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,295 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use std::any::Any; | ||
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array}; | ||
| use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type}; | ||
| use datafusion_common::utils::take_function_args; | ||
| use datafusion_common::{Result, ScalarValue, exec_err}; | ||
| use datafusion_expr::{ | ||
| ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, | ||
| }; | ||
|
|
||
| /// Spark-compatible `ceil` expression | ||
| /// <https://spark.apache.org/docs/latest/api/sql/index.html#ceil> | ||
| /// | ||
| /// Differences with DataFusion ceil: | ||
| /// - Spark's ceil returns Int64 for float inputs; DataFusion preserves | ||
| /// the input type (Float32→Float32, Float64→Float64) | ||
| /// - Spark's ceil on Decimal128(p, s) returns Decimal128(p−s+1, 0), reducing scale | ||
| /// to 0; DataFusion preserves the original precision and scale | ||
| /// - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256 | ||
| /// - Spark does not check for decimal overflow; DataFusion errors on overflow | ||
| /// | ||
| /// TODO: 2-argument ceil(value, scale) is not yet implemented | ||
| #[derive(Debug, PartialEq, Eq, Hash)] | ||
| pub struct SparkCeil { | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl Default for SparkCeil { | ||
| fn default() -> Self { | ||
| Self::new() | ||
| } | ||
| } | ||
|
|
||
| impl SparkCeil { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::numeric(1, Volatility::Immutable), | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl ScalarUDFImpl for SparkCeil { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
|
|
||
| fn name(&self) -> &str { | ||
| "ceil" | ||
| } | ||
|
|
||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| match &arg_types[0] { | ||
| DataType::Decimal128(p, s) => { | ||
| if *s > 0 { | ||
| let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8; | ||
| Ok(DataType::Decimal128(new_p, 0)) | ||
| } else { | ||
| // scale <= 0 means the value is already a whole number | ||
| // (or represents multiples of 10^(-scale)), so ceil is a no-op | ||
| Ok(DataType::Decimal128(*p, *s)) | ||
| } | ||
| } | ||
| dt if dt.is_integer() => Ok(dt.clone()), | ||
| DataType::Float32 | DataType::Float64 => Ok(DataType::Int64), | ||
| other => exec_err!("Unsupported data type {other:?} for function ceil"), | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also recommend using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this UDF I don't think it's strictly necessary, I believe we can infer the output type from |
||
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| spark_ceil(&args.args) | ||
| } | ||
| } | ||
|
|
||
| fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
| let [input] = take_function_args("ceil", args)?; | ||
|
|
||
| match input { | ||
| ColumnarValue::Scalar(value) => spark_ceil_scalar(value), | ||
| ColumnarValue::Array(input) => spark_ceil_array(input), | ||
| } | ||
| } | ||
|
|
||
| fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> { | ||
| let result = match value { | ||
| ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)), | ||
| ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)), | ||
| v if v.data_type().is_integer() => v.clone(), | ||
| ScalarValue::Decimal128(v, p, s) if *s > 0 => { | ||
| let div = 10_i128.pow_wrapping(*s as u32); | ||
| let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8; | ||
| let result = v.map(|x| { | ||
| let d = x / div; | ||
| let r = x % div; | ||
| if r > 0 { d + 1 } else { d } | ||
| }); | ||
| ScalarValue::Decimal128(result, new_p, 0) | ||
| } | ||
| ScalarValue::Decimal128(_, _, _) => value.clone(), | ||
| other => { | ||
| return exec_err!( | ||
| "Unsupported data type {:?} for function ceil", | ||
| other.data_type() | ||
| ); | ||
| } | ||
| }; | ||
| Ok(ColumnarValue::Scalar(result)) | ||
| } | ||
|
|
||
| fn spark_ceil_array(input: &Arc<dyn arrow::array::Array>) -> Result<ColumnarValue> { | ||
| let result = match input.data_type() { | ||
| DataType::Float32 => Arc::new( | ||
| input | ||
| .as_primitive::<Float32Type>() | ||
| .unary::<_, Int64Type>(|x| x.ceil() as i64), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could add an inline function .unary::<_, Int64Type>(|x| x.ceil() as i64) for both float inputs so that we dont repeat ourselves |
||
| ) as _, | ||
| DataType::Float64 => Arc::new( | ||
| input | ||
| .as_primitive::<Float64Type>() | ||
| .unary::<_, Int64Type>(|x| x.ceil() as i64), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could add an inline function
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried this and couldn't figure out a way to make it cleaner than what we have now, I'd prefer to keep it as is unless there's a solution I'm missing which is very possible |
||
| ) as _, | ||
| dt if dt.is_integer() => Arc::clone(input), | ||
| DataType::Decimal128(p, s) if *s > 0 => { | ||
| let div = 10_i128.pow_wrapping(*s as u32); | ||
| let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8; | ||
| let result: Decimal128Array = | ||
| input.as_primitive::<Decimal128Type>().unary(|x| { | ||
| let d = x / div; | ||
| let r = x % div; | ||
| if r > 0 { d + 1 } else { d } | ||
| }); | ||
| Arc::new(result.with_data_type(DataType::Decimal128(new_p, 0))) | ||
| } | ||
| DataType::Decimal128(_, _) => Arc::clone(input), | ||
| other => return exec_err!("Unsupported data type {other:?} for function ceil"), | ||
| }; | ||
|
|
||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
|
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array};
    use datafusion_common::ScalarValue;

    /// Runs `spark_ceil` on one array argument and unwraps the array result.
    fn ceil_of_array(input: Arc<dyn arrow::array::Array>) -> Arc<dyn arrow::array::Array> {
        let args = vec![ColumnarValue::Array(input)];
        match spark_ceil(&args).unwrap() {
            ColumnarValue::Array(arr) => arr,
            _ => panic!("Expected array"),
        }
    }

    /// Runs `spark_ceil` on one scalar argument and unwraps the scalar result.
    fn ceil_of_scalar(input: ScalarValue) -> ScalarValue {
        let args = vec![ColumnarValue::Scalar(input)];
        match spark_ceil(&args).unwrap() {
            ColumnarValue::Scalar(v) => v,
            _ => panic!("Expected scalar"),
        }
    }

    #[test]
    fn test_ceil_float64() {
        let input = Float64Array::from(vec![
            Some(125.2345),
            Some(15.0001),
            Some(0.1),
            Some(-0.9),
            Some(-1.1),
            Some(123.0),
            None,
        ]);
        let expected = Int64Array::from(vec![
            Some(126),
            Some(16),
            Some(1),
            Some(0),
            Some(-1),
            Some(123),
            None,
        ]);

        let actual = ceil_of_array(Arc::new(input));
        assert_eq!(actual.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_float32() {
        let input = Float32Array::from(vec![
            Some(125.2345f32),
            Some(15.0001f32),
            Some(0.1f32),
            Some(-0.9f32),
            Some(-1.1f32),
            Some(123.0f32),
            None,
        ]);
        let expected = Int64Array::from(vec![
            Some(126),
            Some(16),
            Some(1),
            Some(0),
            Some(-1),
            Some(123),
            None,
        ]);

        let actual = ceil_of_array(Arc::new(input));
        assert_eq!(actual.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_int64() {
        // Integer input passes through untouched.
        let input = Int64Array::from(vec![Some(1), Some(-1), None]);
        let expected = Int64Array::from(vec![Some(1), Some(-1), None]);

        let actual = ceil_of_array(Arc::new(input));
        assert_eq!(actual.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_decimal128() {
        // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00
        let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None])
            .with_data_type(DataType::Decimal128(10, 2));
        // Precision shrinks to 10 - 2 + 1 = 9 and the scale drops to 0.
        let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None])
            .with_data_type(DataType::Decimal128(9, 0));

        let actual = ceil_of_array(Arc::new(input));
        assert_eq!(actual.as_primitive::<Decimal128Type>(), &expected);
    }

    #[test]
    fn test_ceil_float64_scalar() {
        let actual = ceil_of_scalar(ScalarValue::Float64(Some(-1.1)));
        assert_eq!(actual, ScalarValue::Int64(Some(-1)));
    }

    #[test]
    fn test_ceil_float32_scalar() {
        let actual = ceil_of_scalar(ScalarValue::Float32(Some(125.2345f32)));
        assert_eq!(actual, ScalarValue::Int64(Some(126)));
    }

    #[test]
    fn test_ceil_int64_scalar() {
        let actual = ceil_of_scalar(ScalarValue::Int64(Some(48)));
        assert_eq!(actual, ScalarValue::Int64(Some(48)));
    }
}
Uh oh!
There was an error while loading. Please reload this page.