Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ tracing = { version = "0.1", features = ["log"] }
url = "2"
uuid = { version = "1.18.0", features = ["v4", "fast-rng"] }
z85 = "3.0.6"
regex = "1"

# optional deps
futures = { version = "0.3", optional = true }
Expand Down
162 changes: 162 additions & 0 deletions kernel/src/engine/arrow_expression/evaluate_expression.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! Expression handling based on arrow-rs compute kernels.
use std::borrow::Cow;
use std::collections::HashSet;
use std::sync::Arc;

use itertools::Itertools;
use regex::Regex;

use crate::arrow::array::types::*;
use crate::arrow::array::{
Expand Down Expand Up @@ -436,6 +438,72 @@ pub fn evaluate_predicate(
}
}

/// Collects field names that have Decimal type with scale=0 from Arrow schema.
/// This recursively traverses nested structs to find all scale=0 decimal fields.
fn collect_scale_zero_decimal_fields(fields: &ArrowFields) -> HashSet<String> {
let mut result = HashSet::new();

fn collect_recursive(fields: &ArrowFields, prefix: &str, result: &mut HashSet<String>) {
for field in fields.iter() {
let field_name = if prefix.is_empty() {
field.name().to_string()
} else {
format!("{}.{}", prefix, field.name())
};

match field.data_type() {
ArrowDataType::Decimal128(_, scale) | ArrowDataType::Decimal256(_, scale) => {
if *scale == 0 {
result.insert(field.name().to_string());
}
}
ArrowDataType::Struct(nested_fields) => {
collect_recursive(nested_fields, &field_name, result);
}
_ => {}
}
}
}

collect_recursive(fields, "", &mut result);
result
}

/// Post-processes JSON bytes to strip ".0" suffix from decimal fields with scale=0.
/// This fixes the issue where Arrow's JSON encoder writes "1234.0" for Decimal(_, 0),
/// but Arrow's JSON parser rejects this format and expects "1234" without a decimal point.
fn fix_scale_zero_decimals_in_json(
json_bytes: &[u8],
scale_zero_fields: &HashSet<String>,
) -> Result<Vec<u8>, ArrowError> {
if scale_zero_fields.is_empty() {
return Ok(json_bytes.to_vec());
}

let json_str = std::str::from_utf8(json_bytes)
.map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid UTF-8 in JSON: {}", e)))?;

let mut fixed = json_str.to_string();

// For each scale=0 decimal field, replace "fieldname": NUMBER.0 with "fieldname": NUMBER
// The regex pattern matches: "fieldname": followed by optional whitespace, then a number,
// then .0 (which might be followed by more zeros), then a trailing context character
// We capture the context character and include it in the replacement
for field_name in scale_zero_fields {
let pattern = format!(
r#""{}"\s*:\s*(-?\d+)\.0+([,\s\}}])"#,
regex::escape(field_name)
);
let re = Regex::new(&pattern)
.map_err(|e| ArrowError::InvalidArgumentError(format!("Regex error: {}", e)))?;
fixed = re
.replace_all(&fixed, format!(r#""{}":${{1}}${{2}}"#, field_name))
.to_string();
}

Ok(fixed.into_bytes())
}

/// Converts a StructArray to JSON-encoded strings
pub fn to_json(input: &dyn Datum) -> Result<ArrayRef, ArrowError> {
let (array_ref, _is_scalar) = input.get();
Expand All @@ -462,6 +530,9 @@ pub fn to_json(input: &dyn Datum) -> Result<ArrayRef, ArrowError> {
let options = EncoderOptions::default().with_struct_mode(StructMode::ObjectOnly);
let mut encoder = make_encoder(&field, struct_array, &options)?;

// Identify fields with Decimal(_, 0) that need post-processing
let scale_zero_fields = collect_scale_zero_decimal_fields(struct_array.fields());

// Pre-allocate the various buffers
const ROW_SIZE_ESTIMATE: usize = 64;
let mut data = Vec::with_capacity(num_rows * ROW_SIZE_ESTIMATE);
Expand All @@ -473,7 +544,19 @@ pub fn to_json(input: &dyn Datum) -> Result<ArrayRef, ArrowError> {
if struct_array.is_null(i) {
nulls.append_null();
} else {
let start_offset = data.len();
encoder.encode(i, &mut data);

// Post-process the JSON bytes that were just added to fix scale=0 decimals
if !scale_zero_fields.is_empty() {
let row_json = &data[start_offset..];
let fixed_json =
fix_scale_zero_decimals_in_json(row_json, &scale_zero_fields)?;
// Replace the bytes in the buffer
data.truncate(start_offset);
data.extend_from_slice(&fixed_json);
}

nulls.append_non_null();
}

Expand Down Expand Up @@ -1099,4 +1182,83 @@ mod tests {
validate_i32_column(nested_struct_result, 0, &[1, 2, 3]);
validate_i32_column(nested_struct_result, 1, &[10, 20, 30]);
}

#[test]
fn test_to_json_decimal_scale_zero() {
use crate::arrow::array::Decimal128Array;
use crate::arrow::datatypes::DataType as ArrowDataType;

// Create arrays with sample data
// For Decimal(10, 0): value 1234 stored as 1234 (no scaling)
// For Decimal(10, 2): value 12.34 stored as 1234 (scaled by 10^2)
let decimal_scale0 = Arc::new(
Decimal128Array::from(vec![1234, 5678])
.with_precision_and_scale(10, 0)
.unwrap(),
);
let decimal_scale2 = Arc::new(
Decimal128Array::from(vec![1234, 5678])
.with_precision_and_scale(10, 2)
.unwrap(),
);
let int_array = Arc::new(Int32Array::from(vec![42, 99]));

let struct_array = StructArray::from(vec![
(
Arc::new(ArrowField::new(
"decimalScale0",
ArrowDataType::Decimal128(10, 0),
false,
)),
decimal_scale0 as ArrayRef,
),
(
Arc::new(ArrowField::new(
"decimalScale2",
ArrowDataType::Decimal128(10, 2),
false,
)),
decimal_scale2 as ArrayRef,
),
(
Arc::new(ArrowField::new("normalInt", ArrowDataType::Int32, false)),
int_array as ArrayRef,
),
]);

// Convert to JSON
let result = to_json(&struct_array).unwrap();
let json_array = result.as_any().downcast_ref::<StringArray>().unwrap();

// Check the JSON strings
assert_eq!(json_array.len(), 2);

// First row
let json1 = json_array.value(0);
assert!(
json1.contains(r#""decimalScale0":1234"#),
"Scale=0 decimal should NOT have .0, got: {}",
json1
);
assert!(
json1.contains(r#""decimalScale2":12.34"#),
"Scale=2 decimal SHOULD have decimal point, got: {}",
json1
);
assert!(json1.contains(r#""normalInt":42"#));

// Second row
let json2 = json_array.value(1);
assert!(
json2.contains(r#""decimalScale0":5678"#),
"Scale=0 decimal should NOT have .0, got: {}",
json2
);
assert!(
json2.contains(r#""decimalScale2":56.78"#),
"Scale=2 decimal SHOULD have decimal point, got: {}",
json2
);
assert!(json2.contains(r#""normalInt":99"#));
}
}
Loading