diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 65f320c4f9f13..db5cd567880a5 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::utils::utf8_to_str_type; use arrow::array::{ Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringArrayType, StringViewArray, + StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder, }; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; @@ -96,6 +96,9 @@ impl ScalarUDFImpl for RepeatFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types[0] == Utf8View { + return Ok(Utf8View); + } utf8_to_str_type(&arg_types[0], "repeat") } @@ -131,13 +134,12 @@ impl ScalarUDFImpl for RepeatFunc { }; let result = match string_scalar { - ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { - ScalarValue::Utf8(Some(compute_repeat( - s, - count, - i32::MAX as usize, - )?)) - } + ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some( + compute_repeat(s, count, i32::MAX as usize)?, + )), + ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some( + compute_repeat(s, count, i32::MAX as usize)?, + )), ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some( compute_repeat(s, count, i64::MAX as usize)?, )), @@ -188,26 +190,47 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { match string_array.data_type() { Utf8View => { let string_view_array = string_array.as_string_view(); - repeat_impl::( + let (_, max_item_capacity) = calculate_capacities( &string_view_array, number_array, i32::MAX as usize, + )?; + let builder = StringViewBuilder::with_capacity(string_array.len()); + repeat_impl::<&StringViewArray, StringViewBuilder>( + &string_view_array, + number_array, + max_item_capacity, + builder, ) } Utf8 => { let string_arr = string_array.as_string::(); - repeat_impl::>( + let (total_capacity, max_item_capacity) = + calculate_capacities(&string_arr, number_array, i32::MAX as usize)?; + let builder = GenericStringBuilder::::with_capacity( + string_array.len(), + total_capacity, + ); + repeat_impl::<&GenericStringArray, GenericStringBuilder>( &string_arr, number_array, - i32::MAX as usize, + max_item_capacity, + builder, ) } LargeUtf8 => { let string_arr = string_array.as_string::(); - repeat_impl::>( + let (total_capacity, max_item_capacity) = + calculate_capacities(&string_arr, number_array, i64::MAX as usize)?; + let builder = GenericStringBuilder::::with_capacity( + string_array.len(), + total_capacity, + ); + repeat_impl::<&GenericStringArray, GenericStringBuilder>( &string_arr, number_array, - i64::MAX as usize, + max_item_capacity, + builder, ) } other => exec_err!( @@ -217,17 +240,17 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { } } -fn repeat_impl<'a, T, S>( +fn calculate_capacities<'a, S>( string_array: &S, number_array: &Int64Array, max_str_len: usize, -) -> Result +) -> Result<(usize, usize)> where - T: OffsetSizeTrait, - S: StringArrayType<'a> + 'a, + S: StringArrayType<'a>, { let mut total_capacity = 0; let mut max_item_capacity = 0; + string_array.iter().zip(number_array.iter()).try_for_each( |(string, number)| -> Result<(), DataFusionError> { match (string, number) { @@ -249,9 +272,19 @@ where }, )?; - let mut builder = - GenericStringBuilder::::with_capacity(string_array.len(), total_capacity); + Ok((total_capacity, max_item_capacity)) +} +fn repeat_impl<'a, S, B>( + string_array: &S, + number_array: &Int64Array, + max_item_capacity: usize, + mut builder: B, +) -> Result +where + S: StringArrayType<'a> + 'a, + B: StringLikeArrayBuilder, +{ // Reusable buffer to avoid allocations in string.repeat() let mut buffer = Vec::::with_capacity(max_item_capacity); @@ -308,8 +341,8 @@ where #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::ScalarValue; use datafusion_common::{Result, exec_err}; @@ -362,8 +395,8 @@ mod tests { ], Ok(Some("PgPgPgPg")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RepeatFunc::new(), @@ -373,8 +406,19 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + test_function!( + RepeatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + LargeUtf8, + LargeStringArray ); test_function!( RepeatFunc::new(), @@ -384,8 +428,8 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RepeatFunc::new(), diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index a07eab3357141..e8e08319e147c 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -347,11 +347,35 @@ SELECT repeat('foo', 3) ---- foofoofoo +query T +SELECT repeat(arrow_cast('foo', 'LargeUtf8'), 3) +---- +foofoofoo + +query T +SELECT repeat(arrow_cast('foo', 'Utf8View'), 3) +---- +foofoofoo + query T SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) ---- foofoofoo +query T +SELECT arrow_typeof(repeat('foo', 3)) +---- +Utf8 + +query T +SELECT arrow_typeof(repeat(arrow_cast('foo', 'LargeUtf8'), 3)) +---- +LargeUtf8 + +query T +SELECT arrow_typeof(repeat(arrow_cast('foo', 'Utf8View'), 3)) +---- +Utf8View query T SELECT replace('foobar', 'bar', 'hello')