diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index b79b43f6c9..15bbabe883 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -126,8 +126,9 @@ use datafusion_comet_proto::{
 use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
-    GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr,
-    RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
+    DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
+    NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson,
+    UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
 };
 use itertools::Itertools;
 use jni::objects::GlobalRef;
@@ -408,10 +409,45 @@ impl PhysicalPlanner {
             )))
         }
         ExprStruct::CheckOverflow(expr) => {
-            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
+            let child =
+                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
             let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
             let fail_on_error = expr.fail_on_error;
 
+            // WideDecimalBinaryExpr already handles overflow — skip redundant check
+            // but only if its output type matches CheckOverflow's declared type
+            if child
+                .as_any()
+                .downcast_ref::<WideDecimalBinaryExpr>()
+                .is_some()
+            {
+                let child_type = child.data_type(&input_schema)?;
+                if child_type == data_type {
+                    return Ok(child);
+                }
+            }
+
+            // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check
+            // Only fuse when the Cast target type matches the CheckOverflow output type
+            if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
+                if let (
+                    DataType::Decimal128(p_out, s_out),
+                    Ok(DataType::Decimal128(_p_in, s_in)),
+                ) = (&data_type, cast.child.data_type(&input_schema))
+                {
+                    let cast_target = cast.data_type(&input_schema)?;
+                    if cast_target == data_type {
+                        return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+                            Arc::clone(&cast.child),
+                            s_in,
+                            *p_out,
+                            *s_out,
+                            fail_on_error,
+                        )));
+                    }
+                }
+            }
+
             // Look up query context from registry if expr_id is present
             let query_context = spark_expr.expr_id.and_then(|expr_id| {
                 let registry = &self.query_context_registry;
@@ -740,29 +776,22 @@ impl PhysicalPlanner {
                 || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
             {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
-                // For some Decimal128 operations, we need wider internal digits.
-                // Cast left and right to Decimal256 and cast the result back to Decimal128
-                let left = Arc::new(Cast::new(
-                    left,
-                    DataType::Decimal256(p1, s1),
-                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
-                    None,
-                    None,
-                ));
-                let right = Arc::new(Cast::new(
-                    right,
-                    DataType::Decimal256(p2, s2),
-                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
-                    None,
-                    None,
-                ));
-                let child = Arc::new(BinaryExpr::new(left, op, right));
-                Ok(Arc::new(Cast::new(
-                    child,
-                    data_type,
-                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
-                    None,
-                    None,
+                let (p_out, s_out) = match &data_type {
+                    DataType::Decimal128(p, s) => (*p, *s),
+                    dt => {
+                        return Err(ExecutionError::GeneralError(format!(
+                            "Expected Decimal128 return type, got {dt:?}"
+                        )))
+                    }
+                };
+                let wide_op = match op {
+                    DataFusionOperator::Plus => WideDecimalOp::Add,
+                    DataFusionOperator::Minus => WideDecimalOp::Subtract,
+                    DataFusionOperator::Multiply => WideDecimalOp::Multiply,
+                    _ => unreachable!(),
+                };
+                Ok(Arc::new(WideDecimalBinaryExpr::new(
+                    left, right, wide_op, p_out, s_out, eval_mode,
                 )))
             }
             (
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 9f08e480f2..d4639c86ea 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -105,6 +105,10 @@ path = "tests/spark_expr_reg.rs"
 name = "cast_from_boolean"
 harness = false
 
+[[bench]]
+name = "wide_decimal"
+harness = false
+
 [[bench]]
 name = "cast_non_int_numeric_timestamp"
 harness = false
diff --git a/native/spark-expr/benches/wide_decimal.rs b/native/spark-expr/benches/wide_decimal.rs
new file mode 100644
index 0000000000..ec932ae68f
--- /dev/null
+++ b/native/spark-expr/benches/wide_decimal.rs
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr
+//! for Decimal128 arithmetic that requires wider intermediate precision.
+
+use arrow::array::builder::Decimal128Builder;
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use datafusion::logical_expr::Operator;
+use datafusion::physical_expr::expressions::{BinaryExpr, Column};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{
+    Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
+};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+/// Build a RecordBatch with two Decimal128 columns.
+fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
+    let mut left = Decimal128Builder::new();
+    let mut right = Decimal128Builder::new();
+    for i in 0..BATCH_SIZE as i128 {
+        left.append_value(123456789012345_i128 + i * 1000);
+        right.append_value(987654321098765_i128 - i * 1000);
+    }
+    let left = left.finish().with_data_type(DataType::Decimal128(p1, s1));
+    let right = right.finish().with_data_type(DataType::Decimal128(p2, s2));
+    let schema = Schema::new(vec![
+        Field::new("left", DataType::Decimal128(p1, s1), false),
+        Field::new("right", DataType::Decimal128(p2, s2), false),
+    ]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap()
+}
+
+/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128).
+fn build_old_expr(
+    p1: u8,
+    s1: i8,
+    p2: u8,
+    s2: i8,
+    op: Operator,
+    out_type: DataType,
+) -> Arc<dyn PhysicalExpr> {
+    let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
+    let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
+    let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
+    let left_cast = Arc::new(Cast::new(
+        left_col,
+        DataType::Decimal256(p1, s1),
+        cast_opts.clone(),
+        None,
+        None,
+    ));
+    let right_cast = Arc::new(Cast::new(
+        right_col,
+        DataType::Decimal256(p2, s2),
+        cast_opts.clone(),
+        None,
+        None,
+    ));
+    let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
+    Arc::new(Cast::new(binary, out_type, cast_opts, None, None))
+}
+
+/// New approach: single fused WideDecimalBinaryExpr.
+fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> {
+    let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
+    let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
+    Arc::new(WideDecimalBinaryExpr::new(
+        left_col,
+        right_col,
+        op,
+        p_out,
+        s_out,
+        EvalMode::Legacy,
+    ))
+}
+
+fn bench_case(
+    group: &mut criterion::BenchmarkGroup<'_, criterion::measurement::WallTime>,
+    name: &str,
+    batch: &RecordBatch,
+    old_expr: &Arc<dyn PhysicalExpr>,
+    new_expr: &Arc<dyn PhysicalExpr>,
+) {
+    group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| {
+        b.iter(|| old_expr.evaluate(batch).unwrap());
+    });
+    group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| {
+        b.iter(|| new_expr.evaluate(batch).unwrap());
+    });
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("wide_decimal");
+
+    // Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10)
+    // Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38
+    {
+        let batch = make_decimal_batch(38, 10, 38, 10);
+        let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10));
+        let new = build_new_expr(WideDecimalOp::Add, 38, 10);
+        bench_case(&mut group, "add_same_scale", &batch, &old, &new);
+    }
+
+    // Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6)
+    {
+        let batch = make_decimal_batch(38, 6, 38, 4);
+        let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6));
+        let new = build_new_expr(WideDecimalOp::Add, 38, 6);
+        bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
+    }
+
+    // Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6)
+    // Triggers wide path because p1 + p2 = 40 >= 38
+    {
+        let batch = make_decimal_batch(20, 10, 20, 10);
+        let old = build_old_expr(
+            20,
+            10,
+            20,
+            10,
+            Operator::Multiply,
+            DataType::Decimal128(38, 6),
+        );
+        let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
+        bench_case(&mut group, "multiply", &batch, &old, &new);
+    }
+
+    // Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18)
+    {
+        let batch = make_decimal_batch(38, 18, 38, 18);
+        let old = build_old_expr(
+            38,
+            18,
+            38,
+            18,
+            Operator::Minus,
+            DataType::Decimal128(38, 18),
+        );
+        let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
+        bench_case(&mut group, "subtract", &batch, &old, &new);
+    }
+
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 072fa1fad7..ba19d6a9b2 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -80,7 +80,8 @@ pub use json_funcs::{FromJson, ToJson};
 pub use math_funcs::{
     create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
-    spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
+    spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
+    NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
 };
 pub use query_context::{create_query_context_map, QueryContext, QueryContextMap};
 pub use string_funcs::*;
diff --git a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
new file mode 100644
index 0000000000..1322404951
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
@@ -0,0 +1,482 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused decimal rescale + overflow check expression.
+//!
+//! Replaces the pattern `CheckOverflow(Cast(expr, Decimal128(p2,s2)), Decimal128(p2,s2))`
+//! with a single expression that rescales and validates precision in one pass.
+
+use arrow::array::{as_primitive_array, Array, ArrayRef, Decimal128Array};
+use arrow::datatypes::{DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+    any::Any,
+    fmt::{Display, Formatter},
+    sync::Arc,
+};
+
+/// A fused expression that rescales a Decimal128 value (changing scale) and checks
+/// for precision overflow in a single pass. Replaces the two-step
+/// `CheckOverflow(Cast(expr, Decimal128(p,s)))` pattern.
+#[derive(Debug, Eq)]
+pub struct DecimalRescaleCheckOverflow {
+    child: Arc<dyn PhysicalExpr>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    fail_on_error: bool,
+}
+
+impl Hash for DecimalRescaleCheckOverflow {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.child.hash(state);
+        self.input_scale.hash(state);
+        self.output_precision.hash(state);
+        self.output_scale.hash(state);
+        self.fail_on_error.hash(state);
+    }
+}
+
+impl PartialEq for DecimalRescaleCheckOverflow {
+    fn eq(&self, other: &Self) -> bool {
+        self.child.eq(&other.child)
+            && self.input_scale == other.input_scale
+            && self.output_precision == other.output_precision
+            && self.output_scale == other.output_scale
+            && self.fail_on_error == other.fail_on_error
+    }
+}
+
+impl DecimalRescaleCheckOverflow {
+    pub fn new(
+        child: Arc<dyn PhysicalExpr>,
+        input_scale: i8,
+        output_precision: u8,
+        output_scale: i8,
+        fail_on_error: bool,
+    ) -> Self {
+        Self {
+            child,
+            input_scale,
+            output_precision,
+            output_scale,
+            fail_on_error,
+        }
+    }
+}
+
+impl Display for DecimalRescaleCheckOverflow {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "DecimalRescaleCheckOverflow [child: {}, input_scale: {}, output: Decimal128({}, {}), fail_on_error: {}]",
+            self.child, self.input_scale, self.output_precision, self.output_scale, self.fail_on_error
+        )
+    }
+}
+
+/// Maximum absolute value for a given decimal precision: 10^p - 1.
+/// Precision must be <= 38 (max for Decimal128).
+#[inline]
+fn precision_bound(precision: u8) -> i128 {
+    assert!(
+        precision <= 38,
+        "precision_bound: precision {precision} exceeds maximum 38"
+    );
+    10i128.pow(precision as u32) - 1
+}
+
+/// Rescale a single i128 value by the given delta (output_scale - input_scale)
+/// and check precision bounds. Returns `Ok(value)` or `Ok(i128::MAX)` as sentinel
+/// for overflow in legacy mode, or `Err` in ANSI mode.
+#[inline]
+fn rescale_and_check(
+    value: i128,
+    delta: i8,
+    scale_factor: i128,
+    bound: i128,
+    fail_on_error: bool,
+) -> Result<i128, ArrowError> {
+    let rescaled = if delta > 0 {
+        // Scale up: multiply. Check for overflow.
+        match value.checked_mul(scale_factor) {
+            Some(v) => v,
+            None => {
+                if fail_on_error {
+                    return Err(ArrowError::ComputeError(
+                        "Decimal overflow during rescale".to_string(),
+                    ));
+                }
+                return Ok(i128::MAX); // sentinel
+            }
+        }
+    } else if delta < 0 {
+        // Scale down with HALF_UP rounding
+        // divisor = 10^(-delta), half = divisor / 2
+        let divisor = scale_factor; // already 10^abs(delta)
+        let half = divisor / 2;
+        let sign = value.signum();
+        (value + sign * half) / divisor
+    } else {
+        value
+    };
+
+    // Precision check
+    if rescaled.abs() > bound {
+        if fail_on_error {
+            return Err(ArrowError::ComputeError(
+                "Decimal overflow: value does not fit in precision".to_string(),
+            ));
+        }
+        Ok(i128::MAX) // sentinel for null_if_overflow_precision
+    } else {
+        Ok(rescaled)
+    }
+}
+
+impl PhysicalExpr for DecimalRescaleCheckOverflow {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+
+    fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
+        Ok(DataType::Decimal128(
+            self.output_precision,
+            self.output_scale,
+        ))
+    }
+
+    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
+        let arg = self.child.evaluate(batch)?;
+        let delta = self.output_scale - self.input_scale;
+        let abs_delta = delta.unsigned_abs();
+        // If abs_delta > 38, the scale factor overflows i128. In that case,
+        // any non-zero value will overflow the output precision, so we treat
+        // it as an immediate overflow condition.
+        if abs_delta > 38 {
+            return Err(DataFusionError::Execution(format!(
+                "DecimalRescaleCheckOverflow: scale delta {delta} exceeds maximum supported range"
+            )));
+        }
+        let scale_factor = 10i128.pow(abs_delta as u32);
+        let bound = precision_bound(self.output_precision);
+        let fail_on_error = self.fail_on_error;
+        let p_out = self.output_precision;
+        let s_out = self.output_scale;
+
+        match arg {
+            ColumnarValue::Array(array)
+                if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
+            {
+                let decimal_array = as_primitive_array::<Decimal128Type>(&array);
+
+                let result: Decimal128Array =
+                    arrow::compute::kernels::arity::try_unary(decimal_array, |value| {
+                        rescale_and_check(value, delta, scale_factor, bound, fail_on_error)
+                    })?;
+
+                let result = if !fail_on_error {
+                    result.null_if_overflow_precision(p_out)
+                } else {
+                    result
+                };
+
+                let result = result
+                    .with_precision_and_scale(p_out, s_out)
+                    .map(|a| Arc::new(a) as ArrayRef)?;
+
+                Ok(ColumnarValue::Array(result))
+            }
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _precision, _scale)) => {
+                let new_v = match v {
+                    Some(val) => {
+                        let r = rescale_and_check(val, delta, scale_factor, bound, fail_on_error)
+                            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+                        if r == i128::MAX {
+                            None
+                        } else {
+                            Some(r)
+                        }
+                    }
+                    None => None,
+                };
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                    new_v, p_out, s_out,
+                )))
+            }
+            v => Err(DataFusionError::Execution(format!(
+                "DecimalRescaleCheckOverflow expects Decimal128, but found {v:?}"
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+        if children.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "DecimalRescaleCheckOverflow expects 1 child, got {}",
+                children.len()
+            )));
+        }
+        Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+            Arc::clone(&children[0]),
+            self.input_scale,
+            self.output_precision,
+            self.output_scale,
+            self.fail_on_error,
+        )))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{AsArray, Decimal128Array};
+    use arrow::datatypes::{Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use datafusion::physical_expr::expressions::Column;
+
+    fn make_batch(values: Vec<Option<i128>>, precision: u8, scale: i8) -> RecordBatch {
+        let arr =
+            Decimal128Array::from(values).with_data_type(DataType::Decimal128(precision, scale));
+        let schema = Schema::new(vec![Field::new("col", arr.data_type().clone(), true)]);
+        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arr)]).unwrap()
+    }
+
+    fn eval_expr(
+        batch: &RecordBatch,
+        input_scale: i8,
+        output_precision: u8,
+        output_scale: i8,
+        fail_on_error: bool,
+    ) -> datafusion::common::Result<ArrayRef> {
+        let child: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col", 0));
+        let expr = DecimalRescaleCheckOverflow::new(
+            child,
+            input_scale,
+            output_precision,
+            output_scale,
+            fail_on_error,
+        );
+        match expr.evaluate(batch)? {
+            ColumnarValue::Array(arr) => Ok(arr),
+            _ => panic!("expected array"),
+        }
+    }
+
+    #[test]
+    fn test_scale_up() {
+        // Decimal128(10,2) -> Decimal128(10,4): 1.50 (150) -> 1.5000 (15000)
+        let batch = make_batch(vec![Some(150), Some(-300)], 10, 2);
+        let result = eval_expr(&batch, 2, 10, 4, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert_eq!(arr.value(0), 15000); // 1.5000
+        assert_eq!(arr.value(1), -30000); // -3.0000
+    }
+
+    #[test]
+    fn test_scale_down_with_half_up_rounding() {
+        // Decimal128(10,4) -> Decimal128(10,2)
+        // 1.2350 (12350) -> round to 1.24 (124) with HALF_UP
+        // 1.2349 (12349) -> round to 1.23 (123) with HALF_UP
+        // -1.2350 (-12350) -> round to -1.24 (-124) with HALF_UP
+        let batch = make_batch(vec![Some(12350), Some(12349), Some(-12350)], 10, 4);
+        let result = eval_expr(&batch, 4, 10, 2, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert_eq!(arr.value(0), 124); // 1.24
+        assert_eq!(arr.value(1), 123); // 1.23
+        assert_eq!(arr.value(2), -124); // -1.24
+    }
+
+    #[test]
+    fn test_same_scale_precision_check_only() {
+        // Same scale, just check precision. Value 999 fits in precision 3, 1000 does not.
+        let batch = make_batch(vec![Some(999), Some(1000)], 38, 0);
+        let result = eval_expr(&batch, 0, 3, 0, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert_eq!(arr.value(0), 999);
+        assert!(arr.is_null(1)); // overflow -> null in legacy mode
+    }
+
+    #[test]
+    fn test_overflow_null_in_legacy_mode() {
+        // Scale up causes overflow: 10^37 * 100 > i128::MAX range for precision 38
+        // Use precision 3, value 10 (which is 10 at scale 0), scale up to scale 2 -> 1000, which overflows precision 3
+        let batch = make_batch(vec![Some(10)], 38, 0);
+        let result = eval_expr(&batch, 0, 3, 2, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert!(arr.is_null(0)); // 10 * 100 = 1000 > 999 (max for precision 3)
+    }
+
+    #[test]
+    fn test_overflow_error_in_ansi_mode() {
+        let batch = make_batch(vec![Some(10)], 38, 0);
+        let result = eval_expr(&batch, 0, 3, 2, true);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_null_propagation() {
+        let batch = make_batch(vec![Some(100), None, Some(200)], 10, 2);
+        let result = eval_expr(&batch, 2, 10, 4, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert!(!arr.is_null(0));
+        assert!(arr.is_null(1));
+        assert!(!arr.is_null(2));
+    }
+
+    #[test]
+    fn test_scalar_path() {
+        let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(10, 2), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+
+        let scalar_expr = DecimalRescaleCheckOverflow::new(
+            Arc::new(ScalarChild(Some(150), 10, 2)),
+            2,
+            10,
+            4,
+            false,
+        );
+        let result = scalar_expr.evaluate(&batch).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, p, s)) => {
+                assert_eq!(v, Some(15000));
+                assert_eq!(p, 10);
+                assert_eq!(s, 4);
+            }
+            _ => panic!("expected decimal scalar"),
+        }
+    }
+
+    /// Helper expression that always returns a Decimal128 scalar.
+    #[derive(Debug, Eq, PartialEq, Hash)]
+    struct ScalarChild(Option<i128>, u8, i8);
+
+    impl Display for ScalarChild {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            write!(f, "ScalarChild({:?})", self.0)
+        }
+    }
+
+    impl PhysicalExpr for ScalarChild {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+        fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
+            Ok(DataType::Decimal128(self.1, self.2))
+        }
+        fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+            Ok(true)
+        }
+        fn evaluate(&self, _batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                self.0, self.1, self.2,
+            )))
+        }
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            vec![]
+        }
+        fn with_new_children(
+            self: Arc<Self>,
+            _children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+            Ok(self)
+        }
+        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            Display::fmt(self, f)
+        }
+    }
+
+    #[test]
+    fn test_scalar_null() {
+        let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(10, 2), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+        let expr =
+            DecimalRescaleCheckOverflow::new(Arc::new(ScalarChild(None, 10, 2)), 2, 10, 4, false);
+        let result = expr.evaluate(&batch).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+                assert_eq!(v, None);
+            }
+            _ => panic!("expected decimal scalar"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_overflow_legacy() {
+        let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(38, 0), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+        let expr = DecimalRescaleCheckOverflow::new(
+            Arc::new(ScalarChild(Some(10), 38, 0)),
+            0,
+            3,
+            2,
+            false,
+        );
+        let result = expr.evaluate(&batch).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+                assert_eq!(v, None); // 10 * 100 = 1000 > 999
+            }
+            _ => panic!("expected decimal scalar"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_overflow_ansi_returns_error() {
+        // fail_on_error=true must propagate the error, not silently return None
+        let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(38, 0), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+        let expr = DecimalRescaleCheckOverflow::new(
+            Arc::new(ScalarChild(Some(10), 38, 0)),
+            0,
+            3,
+            2,
+            true, // fail_on_error = true
+        );
+        let result = expr.evaluate(&batch);
+        assert!(result.is_err()); // must be error, not Ok(None)
+    }
+
+    #[test]
+    fn test_large_scale_delta_returns_error() {
+        // delta = output_scale - input_scale = 38 - (-1) = 39
+        // 10i128.pow(39) would overflow, so we must reject gracefully
+        let batch = make_batch(vec![Some(1)], 38, -1);
+        let result = eval_expr(&batch, -1, 38, 38, false);
+        assert!(result.is_err());
+    }
+}
diff --git a/native/spark-expr/src/math_funcs/internal/mod.rs b/native/spark-expr/src/math_funcs/internal/mod.rs
index 29295f0d52..dff26146e8 100644
--- a/native/spark-expr/src/math_funcs/internal/mod.rs
+++ b/native/spark-expr/src/math_funcs/internal/mod.rs
@@ -16,11 +16,13 @@
 // under the License.
 
 mod checkoverflow;
+mod decimal_rescale_check;
 mod make_decimal;
 mod normalize_nan;
 mod unscaled_value;
 
 pub use checkoverflow::CheckOverflow;
+pub use decimal_rescale_check::DecimalRescaleCheckOverflow;
 pub use make_decimal::spark_make_decimal;
 pub use normalize_nan::NormalizeNaNAndZero;
 pub use unscaled_value::spark_unscaled_value;
diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs
index 35c1dc6504..1219bc7208 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -26,6 +26,7 @@ mod negative;
 mod round;
 pub(crate) mod unhex;
 mod utils;
+mod wide_decimal_binary_expr;
 
 pub use ceil::spark_ceil;
 pub use div::spark_decimal_div;
@@ -36,3 +37,4 @@ pub use modulo_expr::create_modulo_expr;
 pub use negative::{create_negate_expr, NegativeExpr};
 pub use round::spark_round;
 pub use unhex::spark_unhex;
+pub use wide_decimal_binary_expr::{WideDecimalBinaryExpr, WideDecimalOp};
diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
new file mode 100644
index 0000000000..644252b46b
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
@@ -0,0 +1,560 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused wide-decimal binary expression for Decimal128 add/sub/mul that may overflow.
+//!
+//! Instead of building a 4-node expression tree (Cast→BinaryExpr→Cast→Cast), this performs
+//! i256 intermediate arithmetic in a single expression, producing only one output array.
+
+use crate::math_funcs::utils::get_precision_scale;
+use crate::EvalMode;
+use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::{i256, DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::Result;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::fmt::{Display, Formatter};
+use std::hash::Hash;
+use std::{any::Any, sync::Arc};
+
+/// The arithmetic operation to perform.
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
+pub enum WideDecimalOp {
+    Add,
+    Subtract,
+    Multiply,
+}
+
+impl Display for WideDecimalOp {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            WideDecimalOp::Add => write!(f, "+"),
+            WideDecimalOp::Subtract => write!(f, "-"),
+            WideDecimalOp::Multiply => write!(f, "*"),
+        }
+    }
+}
+
+/// A fused expression that evaluates Decimal128 add/sub/mul using i256 intermediate arithmetic,
+/// applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs
+/// a single Decimal128 array.
+#[derive(Debug, Eq)]
+pub struct WideDecimalBinaryExpr {
+    left: Arc<dyn PhysicalExpr>,
+    right: Arc<dyn PhysicalExpr>,
+    op: WideDecimalOp,
+    output_precision: u8,
+    output_scale: i8,
+    eval_mode: EvalMode,
+}
+
+impl Hash for WideDecimalBinaryExpr {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.left.hash(state);
+        self.right.hash(state);
+        self.op.hash(state);
+        self.output_precision.hash(state);
+        self.output_scale.hash(state);
+        self.eval_mode.hash(state);
+    }
+}
+
+impl PartialEq for WideDecimalBinaryExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.left.eq(&other.left)
+            && self.right.eq(&other.right)
+            && self.op == other.op
+            && self.output_precision == other.output_precision
+            && self.output_scale == other.output_scale
+            && self.eval_mode == other.eval_mode
+    }
+}
+
+impl WideDecimalBinaryExpr {
+    pub fn new(
+        left: Arc<dyn PhysicalExpr>,
+        right: Arc<dyn PhysicalExpr>,
+        op: WideDecimalOp,
+        output_precision: u8,
+        output_scale: i8,
+        eval_mode: EvalMode,
+    ) -> Self {
+        Self {
+            left,
+            right,
+            op,
+            output_precision,
+            output_scale,
+            eval_mode,
+        }
+    }
+}
+
+impl Display for WideDecimalBinaryExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "WideDecimalBinaryExpr [{} {} {}, output: Decimal128({}, {})]",
+            self.left, self.op, self.right, self.output_precision, self.output_scale
+        )
+    }
+}
+
+/// Compute `value / divisor` with HALF_UP rounding.
+#[inline]
+fn div_round_half_up(value: i256, divisor: i256) -> i256 {
+    let (quot, rem) = (value / divisor, value % divisor);
+    // HALF_UP: if |remainder| * 2 >= |divisor|, round away from zero
+    let abs_rem_x2 = if rem < i256::ZERO {
+        rem.wrapping_neg()
+    } else {
+        rem
+    }
+    .wrapping_mul(i256::from_i128(2));
+    let abs_divisor = if divisor < i256::ZERO {
+        divisor.wrapping_neg()
+    } else {
+        divisor
+    };
+    if abs_rem_x2 >= abs_divisor {
+        if (value < i256::ZERO) != (divisor < i256::ZERO) {
+            quot.wrapping_sub(i256::ONE)
+        } else {
+            quot.wrapping_add(i256::ONE)
+        }
+    } else {
+        quot
+    }
+}
+
+/// i256 constant for 10.
+const I256_TEN: i256 = i256::from_i128(10);
+
+/// Compute 10^exp as i256. Panics if exp > 76 (max representable power of 10 in i256).
+#[inline]
+fn i256_pow10(exp: u32) -> i256 {
+    assert!(exp <= 76, "i256_pow10: exponent {exp} exceeds maximum 76");
+    let mut result = i256::ONE;
+    for _ in 0..exp {
+        result = result.wrapping_mul(I256_TEN);
+    }
+    result
+}
+
+/// Maximum i128 value for a given decimal precision (1-indexed).
+/// precision p allows values in [-10^p + 1, 10^p - 1].
+#[inline]
+fn max_for_precision(precision: u8) -> i256 {
+    i256_pow10(precision as u32).wrapping_sub(i256::ONE)
+}
+
+impl PhysicalExpr for WideDecimalBinaryExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Decimal128(
+            self.output_precision,
+            self.output_scale,
+        ))
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let left_val = self.left.evaluate(batch)?;
+        let right_val = self.right.evaluate(batch)?;
+
+        let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, &right_val) {
+            (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
+            (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
+                (l.to_array_of_size(r.len())?, Arc::clone(r))
+            }
+            (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
+                (Arc::clone(l), r.to_array_of_size(l.len())?)
+            }
+            (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
+        };
+
+        let left = left_arr.as_primitive::<Decimal128Type>();
+        let right = right_arr.as_primitive::<Decimal128Type>();
+        let (_p1, s1) = get_precision_scale(left.data_type());
+        let (_p2, s2) = get_precision_scale(right.data_type());
+
+        let p_out = self.output_precision;
+        let s_out = self.output_scale;
+        let op = self.op;
+        let eval_mode = self.eval_mode;
+
+        let bound = max_for_precision(p_out);
+        let neg_bound = i256::ZERO.wrapping_sub(bound);
+
+        let result: Decimal128Array = match op {
+            WideDecimalOp::Add | WideDecimalOp::Subtract => {
+                let max_scale = std::cmp::max(s1, s2);
+                let l_scale_up = i256_pow10((max_scale - s1) as u32);
+                let r_scale_up = i256_pow10((max_scale - s2) as u32);
+                // After add/sub at max_scale, we may need to rescale to s_out
+                let scale_diff = max_scale as i16 - s_out as i16;
+                let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0);
+                let rescale_divisor = if need_scale_down {
+                    i256_pow10(scale_diff as u32)
+                } else {
+                    i256::ONE
+                };
+                let scale_up_factor = if need_scale_up {
+                    i256_pow10((-scale_diff) as u32)
+                } else {
+                    i256::ONE
+                };
+
+                arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
+                    let l256 = i256::from_i128(l).wrapping_mul(l_scale_up);
+                    let r256 = i256::from_i128(r).wrapping_mul(r_scale_up);
+                    let raw = match op {
+                        WideDecimalOp::Add => l256.wrapping_add(r256),
+                        WideDecimalOp::Subtract => l256.wrapping_sub(r256),
+                        _ => unreachable!(),
+                    };
+                    let result = if need_scale_down {
+                        div_round_half_up(raw, rescale_divisor)
+                    } else if need_scale_up {
+                        raw.wrapping_mul(scale_up_factor)
+                    } else {
+                        raw
+                    };
+                    check_overflow_and_convert(result, bound, neg_bound, eval_mode)
+                })?
+            }
+            WideDecimalOp::Multiply => {
+                let natural_scale = s1 + s2;
+                let scale_diff = natural_scale as i16 - s_out as i16;
+                let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0);
+                let rescale_divisor = if need_scale_down {
+                    i256_pow10(scale_diff as u32)
+                } else {
+                    i256::ONE
+                };
+                let scale_up_factor = if need_scale_up {
+                    i256_pow10((-scale_diff) as u32)
+                } else {
+                    i256::ONE
+                };
+
+                arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
+                    let raw = i256::from_i128(l).wrapping_mul(i256::from_i128(r));
+                    let result = if need_scale_down {
+                        div_round_half_up(raw, rescale_divisor)
+                    } else if need_scale_up {
+                        raw.wrapping_mul(scale_up_factor)
+                    } else {
+                        raw
+                    };
+                    check_overflow_and_convert(result, bound, neg_bound, eval_mode)
+                })?
+            }
+        };
+
+        let result = if eval_mode != EvalMode::Ansi {
+            result.null_if_overflow_precision(p_out)
+        } else {
+            result
+        };
+        let result = result.with_data_type(DataType::Decimal128(p_out, s_out));
+        Ok(ColumnarValue::Array(Arc::new(result)))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.left, &self.right]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        if children.len() != 2 {
+            return Err(datafusion::common::DataFusionError::Internal(format!(
+                "WideDecimalBinaryExpr expects 2 children, got {}",
+                children.len()
+            )));
+        }
+        Ok(Arc::new(WideDecimalBinaryExpr::new(
+            Arc::clone(&children[0]),
+            Arc::clone(&children[1]),
+            self.op,
+            self.output_precision,
+            self.output_scale,
+            self.eval_mode,
+        )))
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+}
+
+/// Check if the i256 result fits in the output precision. In Ansi mode, return an error
+/// on overflow. In Legacy/Try mode, return i128::MAX as a sentinel value that will be
+/// nullified by `null_if_overflow_precision`.
+#[inline] +fn check_overflow_and_convert( + result: i256, + bound: i256, + neg_bound: i256, + eval_mode: EvalMode, +) -> Result { + if result > bound || result < neg_bound { + if eval_mode == EvalMode::Ansi { + return Err(ArrowError::ComputeError("Arithmetic overflow".to_string())); + } + // Sentinel value — will be nullified by null_if_overflow_precision + Ok(i128::MAX) + } else { + Ok(result.to_i128().unwrap()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Decimal128Array; + use arrow::datatypes::{Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_expr::expressions::Column; + + fn make_batch( + left_values: Vec>, + left_precision: u8, + left_scale: i8, + right_values: Vec>, + right_precision: u8, + right_scale: i8, + ) -> RecordBatch { + let left_arr = Decimal128Array::from(left_values) + .with_data_type(DataType::Decimal128(left_precision, left_scale)); + let right_arr = Decimal128Array::from(right_values) + .with_data_type(DataType::Decimal128(right_precision, right_scale)); + let schema = Schema::new(vec![ + Field::new("left", left_arr.data_type().clone(), true), + Field::new("right", right_arr.data_type().clone(), true), + ]); + RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(left_arr), Arc::new(right_arr)], + ) + .unwrap() + } + + fn eval_expr( + batch: &RecordBatch, + op: WideDecimalOp, + output_precision: u8, + output_scale: i8, + eval_mode: EvalMode, + ) -> Result { + let left: Arc = Arc::new(Column::new("left", 0)); + let right: Arc = Arc::new(Column::new("right", 1)); + let expr = + WideDecimalBinaryExpr::new(left, right, op, output_precision, output_scale, eval_mode); + match expr.evaluate(batch)? 
{ + ColumnarValue::Array(arr) => Ok(arr), + _ => panic!("expected array"), + } + } + + #[test] + fn test_add_same_scale() { + // Decimal128(38, 10) + Decimal128(38, 10) -> Decimal128(38, 10) + let batch = make_batch( + vec![Some(1000000000), Some(2500000000)], // 0.1, 0.25 (scale 10 → divide by 10^10 mentally) + 38, + 10, + vec![Some(2000000000), Some(7500000000)], + 38, + 10, + ); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 10, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 3000000000); // 0.1 + 0.2 + assert_eq!(arr.value(1), 10000000000); // 0.25 + 0.75 + } + + #[test] + fn test_subtract_same_scale() { + let batch = make_batch( + vec![Some(5000), Some(1000)], + 38, + 2, + vec![Some(3000), Some(2000)], + 38, + 2, + ); + let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 2, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2000); // 50.00 - 30.00 + assert_eq!(arr.value(1), -1000); // 10.00 - 20.00 + } + + #[test] + fn test_add_different_scales() { + // Decimal128(10, 2) + Decimal128(10, 4) -> output scale 4 + let batch = make_batch( + vec![Some(150)], // 1.50 + 10, + 2, + vec![Some(2500)], // 0.2500 + 10, + 4, + ); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 17500); // 1.5000 + 0.2500 = 1.7500 + } + + #[test] + fn test_multiply_with_scale_reduction() { + // Decimal128(20, 5) * Decimal128(20, 5) -> natural scale 10, output scale 6 + // 1.00000 * 2.00000 = 2.000000 + let batch = make_batch( + vec![Some(100000)], // 1.00000 + 20, + 5, + vec![Some(200000)], // 2.00000 + 20, + 5, + ); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 6, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2000000); // 2.000000 + } + + #[test] + fn test_multiply_half_up_rounding() { + // Test HALF_UP rounding: 1.5 * 1.5 = 
2.25, but if output scale=1, should round to 2.3 + // Input: scale 1, values 15 (1.5) * 15 (1.5) = natural scale 2, raw = 225 + // Output scale 1: 225 / 10 = 22 remainder 5 -> HALF_UP rounds to 23 + let batch = make_batch( + vec![Some(15)], // 1.5 + 10, + 1, + vec![Some(15)], // 1.5 + 10, + 1, + ); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 23); // 2.3 + } + + #[test] + fn test_multiply_half_up_rounding_negative() { + // -1.5 * 1.5 = -2.25, output scale 1: -225/10 => -22 rem -5 -> HALF_UP rounds to -23 + let batch = make_batch( + vec![Some(-15)], // -1.5 + 10, + 1, + vec![Some(15)], // 1.5 + 10, + 1, + ); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), -23); // -2.3 + } + + #[test] + fn test_overflow_legacy_mode_returns_null() { + // Use precision 1 (max value 9), so 5 + 5 = 10 overflows + let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0); + let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert!(arr.is_null(0)); + } + + #[test] + fn test_overflow_ansi_mode_returns_error() { + let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0); + let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Ansi); + assert!(result.is_err()); + } + + #[test] + fn test_null_propagation() { + let batch = make_batch(vec![Some(100), None], 10, 2, vec![None, Some(200)], 10, 2); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 2, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert!(arr.is_null(0)); + assert!(arr.is_null(1)); + } + + #[test] + fn test_zeros() { + let batch = make_batch(vec![Some(0)], 38, 10, vec![Some(0)], 38, 10); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 10, EvalMode::Legacy).unwrap(); + let 
arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); + } + + #[test] + fn test_max_precision_values() { + // Max Decimal128(38,0) value: 10^38 - 1 + let max_val = 10i128.pow(38) - 1; + let batch = make_batch(vec![Some(max_val)], 38, 0, vec![Some(0)], 38, 0); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 0, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), max_val); + } + + #[test] + fn test_add_scale_up_to_output() { + // When s_out > max(s1, s2), result must be scaled UP + // Decimal128(10, 2) + Decimal128(10, 2) with output scale 4 + // 1.50 + 0.25 = 1.75, at scale 4 = 17500 + let batch = make_batch( + vec![Some(150)], // 1.50 + 10, + 2, + vec![Some(25)], // 0.25 + 10, + 2, + ); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 17500); // 1.7500 + } + + #[test] + fn test_subtract_scale_up_to_output() { + // s_out (4) > max(s1, s2) (2) — verify scale-up path for subtract + let batch = make_batch( + vec![Some(300)], // 3.00 + 10, + 2, + vec![Some(100)], // 1.00 + 10, + 2, + ); + let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 4, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 20000); // 2.0000 + } +}