diff --git a/datafusion/spark/src/function/math/ceil.rs b/datafusion/spark/src/function/math/ceil.rs
new file mode 100644
index 0000000000000..2757ec835a88f
--- /dev/null
+++ b/datafusion/spark/src/function/math/ceil.rs
@@ -0,0 +1,295 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrowNativeTypeOp, AsArray, Decimal128Array};
+use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type};
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{Result, ScalarValue, exec_err};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+
+/// Spark-compatible `ceil` expression
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#ceil>
+///
+/// Differences with DataFusion ceil:
+/// - Spark's ceil returns Int64 for float inputs; DataFusion preserves
+///   the input type (Float32→Float32, Float64→Float64)
+/// - Spark's ceil on Decimal128(p, s) returns Decimal128(p−s+1, 0), reducing scale
+///   to 0; DataFusion preserves the original precision and scale
+/// - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256
+/// - Spark does not check for decimal overflow; DataFusion errors on overflow
+///
+/// TODO: 2-argument ceil(value, scale) is not yet implemented
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCeil {
+    signature: Signature,
+}
+
+impl Default for SparkCeil {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCeil {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::numeric(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkCeil {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ceil"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match &arg_types[0] {
+            DataType::Decimal128(p, s) => {
+                if *s > 0 {
+                    let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8;
+                    Ok(DataType::Decimal128(new_p, 0))
+                } else {
+                    // scale <= 0 means the value is already a whole number
+                    // (or represents multiples of 10^(-scale)), so ceil is a no-op
+                    Ok(DataType::Decimal128(*p, *s))
+                }
+            }
+            dt if dt.is_integer() => Ok(dt.clone()),
+            DataType::Float32 | DataType::Float64 => Ok(DataType::Int64),
+            other => exec_err!("Unsupported data type {other:?} for function ceil"),
+        }
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_ceil(&args.args)
+    }
+}
+
+fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let [input] = take_function_args("ceil", args)?;
+
+    match input {
+        ColumnarValue::Scalar(value) => spark_ceil_scalar(value),
+        ColumnarValue::Array(input) => spark_ceil_array(input),
+    }
+}
+
+fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
+    let result = match value {
+        ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
+        ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
+        v if v.data_type().is_integer() => v.clone(),
+        ScalarValue::Decimal128(v, p, s) if *s > 0 => {
+            let div = 10_i128.pow_wrapping(*s as u32);
+            let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8;
+            let result = v.map(|x| {
+                let d = x / div;
+                let r = x % div;
+                if r > 0 { d + 1 } else { d }
+            });
+            ScalarValue::Decimal128(result, new_p, 0)
+        }
+        ScalarValue::Decimal128(_, _, _) => value.clone(),
+        other => {
+            return exec_err!(
+                "Unsupported data type {:?} for function ceil",
+                other.data_type()
+            );
+        }
+    };
+    Ok(ColumnarValue::Scalar(result))
+}
+
+fn spark_ceil_array(input: &Arc<dyn Array>) -> Result<ColumnarValue> {
+    let result = match input.data_type() {
+        DataType::Float32 => Arc::new(
+            input
+                .as_primitive::<Float32Type>()
+                .unary::<_, Int64Type>(|x| x.ceil() as i64),
+        ) as _,
+        DataType::Float64 => Arc::new(
+            input
+                .as_primitive::<Float64Type>()
+                .unary::<_, Int64Type>(|x| x.ceil() as i64),
+        ) as _,
+        dt if dt.is_integer() => Arc::clone(input),
+        DataType::Decimal128(p, s) if *s > 0 => {
+            let div = 10_i128.pow_wrapping(*s as u32);
+            let new_p = ((*p as i64) - (*s as i64) + 1).clamp(1, 38) as u8;
+            let result: Decimal128Array =
+                input.as_primitive::<Decimal128Type>().unary(|x| {
+                    let d = x / div;
+                    let r = x % div;
+                    if r > 0 { d + 1 } else { d }
+                });
+            Arc::new(result.with_data_type(DataType::Decimal128(new_p, 0)))
+        }
+        DataType::Decimal128(_, _) => Arc::clone(input),
+        other => return exec_err!("Unsupported data type {other:?} for function ceil"),
+    };
+
+    Ok(ColumnarValue::Array(result))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array};
+    use datafusion_common::ScalarValue;
+
+    #[test]
+    fn test_ceil_float64() {
+        let input = Float64Array::from(vec![
+            Some(125.2345),
+            Some(15.0001),
+            Some(0.1),
+            Some(-0.9),
+            Some(-1.1),
+            Some(123.0),
+            None,
+        ]);
+        let args = vec![ColumnarValue::Array(Arc::new(input))];
+        let result = spark_ceil(&args).unwrap();
+        let result = match result {
+            ColumnarValue::Array(arr) => arr,
+            _ => panic!("Expected array"),
+        };
+        let result = result.as_primitive::<Int64Type>();
+        assert_eq!(
+            result,
+            &Int64Array::from(vec![
+                Some(126),
+                Some(16),
+                Some(1),
+                Some(0),
+                Some(-1),
+                Some(123),
+                None,
+            ])
+        );
+    }
+
+    #[test]
+    fn test_ceil_float32() {
+        let input = Float32Array::from(vec![
+            Some(125.2345f32),
+            Some(15.0001f32),
+            Some(0.1f32),
+            Some(-0.9f32),
+            Some(-1.1f32),
+            Some(123.0f32),
+            None,
+        ]);
+        let args = vec![ColumnarValue::Array(Arc::new(input))];
+        let result = spark_ceil(&args).unwrap();
+        let result = match result {
+            ColumnarValue::Array(arr) => arr,
+            _ => panic!("Expected array"),
+        };
+        let result = result.as_primitive::<Int64Type>();
+        assert_eq!(
+            result,
+            &Int64Array::from(vec![
+                Some(126),
+                Some(16),
+                Some(1),
+                Some(0),
+                Some(-1),
+                Some(123),
+                None,
+            ])
+        );
+    }
+
+    #[test]
+    fn test_ceil_int64() {
+        let input = Int64Array::from(vec![Some(1), Some(-1), None]);
+        let args = vec![ColumnarValue::Array(Arc::new(input))];
+        let result = spark_ceil(&args).unwrap();
+        let result = match result {
+            ColumnarValue::Array(arr) => arr,
+            _ => panic!("Expected array"),
+        };
+        let result = result.as_primitive::<Int64Type>();
+        assert_eq!(result, &Int64Array::from(vec![Some(1), Some(-1), None]));
+    }
+
+    #[test]
+    fn test_ceil_decimal128() {
+        // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00
+        let return_type = DataType::Decimal128(9, 0);
+        let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None])
+            .with_data_type(DataType::Decimal128(10, 2));
+        let args = vec![ColumnarValue::Array(Arc::new(input))];
+        let result = spark_ceil(&args).unwrap();
+        let result = match result {
+            ColumnarValue::Array(arr) => arr,
+            _ => panic!("Expected array"),
+        };
+        let result = result.as_primitive::<Decimal128Type>();
+        let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None])
+            .with_data_type(return_type);
+        assert_eq!(result, &expected);
+    }
+
+    #[test]
+    fn test_ceil_float64_scalar() {
+        let input = ScalarValue::Float64(Some(-1.1));
+        let args = vec![ColumnarValue::Scalar(input)];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert_eq!(result, ScalarValue::Int64(Some(-1)));
+    }
+
+    #[test]
+    fn test_ceil_float32_scalar() {
+        let input = ScalarValue::Float32(Some(125.2345f32));
+        let args = vec![ColumnarValue::Scalar(input)];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert_eq!(result, ScalarValue::Int64(Some(126)));
+    }
+
+    #[test]
+    fn test_ceil_int64_scalar() {
+        let input = ScalarValue::Int64(Some(48));
+        let args = vec![ColumnarValue::Scalar(input)];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert_eq!(result, ScalarValue::Int64(Some(48)));
+    }
+}
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 7f7d04e06b0be..dc2b136b4e91a 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -17,6 +17,7 @@
 pub mod abs;
 pub mod bin;
+pub mod ceil;
 pub mod expm1;
 pub mod factorial;
 pub mod hex;
@@ -32,6 +33,7 @@ use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
 make_udf_function!(abs::SparkAbs, abs);
+make_udf_function!(ceil::SparkCeil, ceil);
 make_udf_function!(expm1::SparkExpm1, expm1);
 make_udf_function!(factorial::SparkFactorial, factorial);
 make_udf_function!(hex::SparkHex, hex);
@@ -49,6 +51,7 @@ pub mod expr_fn {
     use datafusion_functions::export_functions;
 
     export_functions!((abs, "Returns abs(expr)", arg1));
+    export_functions!((ceil, "Returns the ceiling of expr.", arg1));
     export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
     export_functions!((
         factorial,
@@ -82,6 +85,7 @@ pub mod expr_fn {
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         abs(),
+        ceil(),
         expm1(),
         factorial(),
         hex(),
diff --git a/datafusion/sqllogictest/test_files/spark/math/ceil.slt b/datafusion/sqllogictest/test_files/spark/math/ceil.slt
index c87a29b61fd49..9c8938be5becd 100644
--- a/datafusion/sqllogictest/test_files/spark/math/ceil.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/ceil.slt
@@ -21,22 +21,138 @@
 # For more information, please see:
 #   https://github.com/apache/datafusion/issues/15914
 
+# Tests for Spark-compatible ceil function.
+# Spark semantics differ from DataFusion's built-in ceil in two ways:
+# 1. Return type: Spark returns Int64 for float/integer inputs;
+#    DataFusion returns the same float type (e.g. ceil(1.5::DOUBLE) -> DOUBLE in DF, BIGINT in Spark)
+# 2. Decimal precision: Spark adjusts precision to (p - s + 1) for Decimal128(p, s) with scale > 0;
+#    DataFusion preserves the original precision and scale
+#
+# Example: SELECT ceil(1.50::DECIMAL(10,2))
+#   Spark: returns Decimal(9, 0) value 2
+#   DataFusion: returns Decimal(10, 2) value 2.00
+
 ## Original Query: SELECT ceil(-0.1);
 ## PySpark 3.5.5 Result: {'CEIL(-0.1)': Decimal('0'), 'typeof(CEIL(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'}
-#query
-#SELECT ceil(-0.1::decimal(1,1));
+query R
+SELECT ceil(-0.1::decimal(1,1));
+----
+0
 
 ## Original Query: SELECT ceil(3.1411, -3);
 ## PySpark 3.5.5 Result: {'ceil(3.1411, -3)': Decimal('1000'), 'typeof(ceil(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'}
+## TODO: 2-argument ceil(value, scale) is not yet implemented
 #query
 #SELECT ceil(3.1411::decimal(5,4), -3::int);
 
 ## Original Query: SELECT ceil(3.1411, 3);
 ## PySpark 3.5.5 Result: {'ceil(3.1411, 3)': Decimal('3.142'), 'typeof(ceil(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'}
+## TODO: 2-argument ceil(value, scale) is not yet implemented
 #query
 #SELECT ceil(3.1411::decimal(5,4), 3::int);
 
 ## Original Query: SELECT ceil(5);
 ## PySpark 3.5.5 Result: {'CEIL(5)': 5, 'typeof(CEIL(5))': 'bigint', 'typeof(5)': 'int'}
-#query
-#SELECT ceil(5::int);
+query I
+SELECT ceil(5::int);
+----
+5
+
+# Scalar input: float64 returns bigint
+query IIIIIII
+SELECT ceil(125.2345::DOUBLE), ceil(15.0001::DOUBLE), ceil(0.1::DOUBLE), ceil(-0.9::DOUBLE), ceil(-1.1::DOUBLE), ceil(123.0::DOUBLE), ceil(NULL::DOUBLE);
+----
+126 16 1 0 -1 123 NULL
+
+# Scalar input: float32 returns bigint
+query IIIIIII
+SELECT ceil(125.2345::FLOAT), ceil(15.0001::FLOAT), ceil(0.1::FLOAT), ceil(-0.9::FLOAT), ceil(-1.1::FLOAT), ceil(123.0::FLOAT), ceil(NULL::FLOAT);
+----
+126 16 1 0 -1 123 NULL
+
+# Scalar input: integer types all return bigint
+query III
+SELECT ceil(5::TINYINT), ceil(-3::TINYINT), ceil(NULL::TINYINT);
+----
+5 -3 NULL
+
+query III
+SELECT ceil(5::SMALLINT), ceil(-3::SMALLINT), ceil(NULL::SMALLINT);
+----
+5 -3 NULL
+
+query III
+SELECT ceil(5::INT), ceil(-3::INT), ceil(NULL::INT);
+----
+5 -3 NULL
+
+query III
+SELECT ceil(5::BIGINT), ceil(-3::BIGINT), ceil(NULL::BIGINT);
+----
+5 -3 NULL
+
+# Scalar input: decimal128 with scale > 0 returns decimal with scale 0
+# ceil(1.50) = 2, ceil(-1.50) = -1, ceil(1.00) = 1
+query RRR
+SELECT ceil(1.50::DECIMAL(10, 2)), ceil(-1.50::DECIMAL(10, 2)), ceil(1.00::DECIMAL(10, 2));
+----
+2 -1 1
+
+# ceil(-0.1) = 0 (smallest positive decimal rounds up to 0 for negatives)
+query RR
+SELECT ceil(-0.1::DECIMAL(3, 1)), ceil(NULL::DECIMAL(10, 2));
+----
+0 NULL
+
+# ceil(3.1411) = 4
+query R
+SELECT ceil(3.1411::DECIMAL(5, 4));
+----
+4
+
+# Scalar input: decimal128 with scale = 0 passes through unchanged
+query RRR
+SELECT ceil(5::DECIMAL(10, 0)), ceil(-3::DECIMAL(10, 0)), ceil(NULL::DECIMAL(10, 0));
+----
+5 -3 NULL
+
+# Array input: float64
+query I
+SELECT ceil(a) FROM (VALUES (125.2345::DOUBLE), (15.0001::DOUBLE), (0.1::DOUBLE), (-0.9::DOUBLE), (-1.1::DOUBLE), (123.0::DOUBLE), (NULL::DOUBLE)) AS t(a);
+----
+126
+16
+1
+0
+-1
+123
+NULL
+
+# Array input: float32
+query I
+SELECT ceil(a) FROM (VALUES (125.2345::FLOAT), (15.0001::FLOAT), (0.1::FLOAT), (-0.9::FLOAT), (-1.1::FLOAT), (123.0::FLOAT), (NULL::FLOAT)) AS t(a);
+----
+126
+16
+1
+0
+-1
+123
+NULL
+
+# Array input: integers
+query I
+SELECT ceil(a) FROM (VALUES (5::INT), (-3::INT), (NULL::INT)) AS t(a);
+----
+5
+-3
+NULL
+
+# Array input: decimal128 with scale > 0
+query R
+SELECT ceil(a) FROM (VALUES (1.50::DECIMAL(10, 2)), (-1.50::DECIMAL(10, 2)), (1.00::DECIMAL(10, 2)), (NULL::DECIMAL(10, 2))) AS t(a);
+----
+2
+-1
+1
+NULL
diff --git a/datafusion/sqllogictest/test_files/spark/math/ceiling.slt b/datafusion/sqllogictest/test_files/spark/math/ceiling.slt
deleted file mode 100644
index 2b761faef47df..0000000000000
--- a/datafusion/sqllogictest/test_files/spark/math/ceiling.slt
+++ /dev/null
@@ -1,42 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-
-#   http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# This file was originally created by a porting script from:
-#   https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
-# This file is part of the implementation of the datafusion-spark function library.
-# For more information, please see:
-#   https://github.com/apache/datafusion/issues/15914
-
-## Original Query: SELECT ceiling(-0.1);
-## PySpark 3.5.5 Result: {'ceiling(-0.1)': Decimal('0'), 'typeof(ceiling(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'}
-#query
-#SELECT ceiling(-0.1::decimal(1,1));
-
-## Original Query: SELECT ceiling(3.1411, -3);
-## PySpark 3.5.5 Result: {'ceiling(3.1411, -3)': Decimal('1000'), 'typeof(ceiling(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'}
-#query
-#SELECT ceiling(3.1411::decimal(5,4), -3::int);
-
-## Original Query: SELECT ceiling(3.1411, 3);
-## PySpark 3.5.5 Result: {'ceiling(3.1411, 3)': Decimal('3.142'), 'typeof(ceiling(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'}
-#query
-#SELECT ceiling(3.1411::decimal(5,4), 3::int);
-
-## Original Query: SELECT ceiling(5);
-## PySpark 3.5.5 Result: {'ceiling(5)': 5, 'typeof(ceiling(5))': 'bigint', 'typeof(5)': 'int'}
-#query
-#SELECT ceiling(5::int);