diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 39e1e13ce39a..491fac272c2c 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -24,9 +24,10 @@ use datafusion::{ logical_expr::Volatility, }; +use datafusion::error::Result; use datafusion::prelude::*; -use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use datafusion_common::cast::as_float64_array; +use datafusion_expr::ColumnarValue; use std::sync::Arc; /// create local execution context with an in-memory table: @@ -61,7 +62,7 @@ async fn main() -> Result<()> { let ctx = create_context()?; // First, declare the actual implementation of the calculation - let pow = |args: &[ArrayRef]| { + let pow = Arc::new(|args: &[ColumnarValue]| { // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: // 1. cast the values to the type we want // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result @@ -69,9 +70,22 @@ async fn main() -> Result<()> { // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); + // Try to obtain row number + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let inferred_length = len.unwrap_or(1); + + let arg0 = args[0].clone().into_array(inferred_length)?; + let arg1 = args[1].clone().into_array(inferred_length)?; + // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! - let base = as_float64_array(&args[0]).expect("cast failed"); - let exponent = as_float64_array(&args[1]).expect("cast failed"); + let base = as_float64_array(&arg0).expect("cast failed"); + let exponent = as_float64_array(&arg1).expect("cast failed"); // this is guaranteed by DataFusion. We place it just to make it obvious. assert_eq!(exponent.len(), base.len()); @@ -92,11 +106,8 @@ async fn main() -> Result<()> { // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) // `Arc` because arrays are immutable, thread-safe, trait objects. - Ok(Arc::new(array) as ArrayRef) - }; - // the function above expects an `ArrayRef`, but DataFusion may pass a scalar to a UDF. - // thus, we use `make_scalar_function` to decorare the closure so that it can handle both Arrays and Scalar values. - let pow = make_scalar_function(pow); + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + }); // Next: // * give it a name so that it shows nicely when the plan is printed diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index fe88ea6cf115..b8573a690e7b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -19,10 +19,7 @@ use arrow::compute::kernels::numeric::add; use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; -use datafusion::{ - execution::registry::FunctionRegistry, - physical_plan::functions::make_scalar_function, test_util, -}; +use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue}; use datafusion_expr::{ @@ -87,12 +84,18 @@ async fn scalar_udf() -> Result<()> { ctx.register_batch("t", batch)?; - let myfunc = |args: &[ArrayRef]| { - let l = as_int32_array(&args[0])?; - let r = as_int32_array(&args[1])?; - Ok(Arc::new(add(l, r)?) as ArrayRef) - }; - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(l) = &args[0] else { + panic!("should be array") + }; + let ColumnarValue::Array(r) = &args[1] else { + panic!("should be array") + }; + + let l = as_int32_array(l)?; + let r = as_int32_array(r)?; + Ok(ColumnarValue::from(Arc::new(add(l, r)?) as ArrayRef)) + }); ctx.register_udf(create_udf( "my_add", @@ -163,11 +166,14 @@ async fn scalar_udf_zero_params() -> Result<()> { ctx.register_batch("t", batch)?; // create function just returns 100 regardless of inp - let myfunc = |args: &[ArrayRef]| { - let num_rows = args[0].len(); - Ok(Arc::new((0..num_rows).map(|_| 100).collect::()) as ArrayRef) - }; - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Scalar(_) = &args[0] else { + panic!("expect scalar") + }; + Ok(ColumnarValue::Array( + Arc::new((0..1).map(|_| 100).collect::()) as ArrayRef, + )) + }); ctx.register_udf(create_udf( "get_100", @@ -307,8 +313,12 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; ctx.register_batch("t", batch).unwrap(); - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!("should be array") + }; + Ok(ColumnarValue::from(Arc::clone(array))) + }); ctx.register_udf(create_udf( "MY_FUNC", @@ -348,8 +358,12 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; ctx.register_batch("t", batch).unwrap(); - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!("should be array") + }; + Ok(ColumnarValue::from(Arc::clone(array))) + }); let udf = create_udf( "dummy", diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index 7a2883928169..58c534b50aad 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -37,6 +37,18 @@ pub enum ColumnarValue { Scalar(ScalarValue), } +impl From for ColumnarValue { + fn from(value: ArrayRef) -> Self { + ColumnarValue::Array(value) + } +} + +impl From for ColumnarValue { + fn from(value: ScalarValue) -> Self { + ColumnarValue::Scalar(value) + } +} + impl ColumnarValue { pub fn data_type(&self) -> DataType { match self { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3ba343003e33..674e85a55c92 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1321,9 +1321,7 @@ mod tests { assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, }; use datafusion_expr::{interval_arithmetic::Interval, *}; - use datafusion_physical_expr::{ - execution_props::ExecutionProps, functions::make_scalar_function, - }; + use datafusion_physical_expr::execution_props::ExecutionProps; use chrono::{DateTime, TimeZone, Utc}; @@ -1438,9 +1436,31 @@ mod tests { let input_types = vec![DataType::Int32, DataType::Int32]; let return_type = Arc::new(DataType::Int32); - let fun = |args: &[ArrayRef]| { - let arg0 = as_int32_array(&args[0])?; - let arg1 = as_int32_array(&args[1])?; + let fun = Arc::new(|args: &[ColumnarValue]| { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let inferred_length = len.unwrap_or(1); + + let arg0 = match &args[0] { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.to_array_of_size(inferred_length).unwrap() + } + }; + let arg1 = match &args[1] { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.to_array_of_size(inferred_length).unwrap() + } + }; + + let arg0 = as_int32_array(&arg0)?; + let arg1 = as_int32_array(&arg1)?; // 2. perform the computation let array = arg0 @@ -1456,10 +1476,9 @@ mod tests { }) .collect::(); - Ok(Arc::new(array) as ArrayRef) - }; + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + }); - let fun = make_scalar_function(fun); Arc::new(create_udf( "udf_add", input_types, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 66e22d2302de..d1e75bfe4f56 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -191,9 +191,23 @@ pub(crate) enum Hint { AcceptsSingular, } -/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function +/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. +/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. +/// That's said its output will be same for all input rows in a batch. +#[deprecated( + since = "36.0.0", + note = "Implement your function directly in terms of ColumnarValue or use `ScalarUDF` instead" +)] pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + make_scalar_function_inner(inner) +} + +/// Internal implementation, see comments on `make_scalar_function` for caveats +pub(crate) fn make_scalar_function_inner(inner: F) -> ScalarFunctionImplementation where F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, { @@ -260,9 +274,9 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions - BuiltinScalarFunction::Abs => { - Arc::new(|args| make_scalar_function(math_expressions::abs_invoke)(args)) - } + BuiltinScalarFunction::Abs => Arc::new(|args| { + make_scalar_function_inner(math_expressions::abs_invoke)(args) + }), BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), @@ -275,31 +289,31 @@ pub fn create_physical_fun( BuiltinScalarFunction::Degrees => Arc::new(math_expressions::to_degrees), BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), BuiltinScalarFunction::Factorial => { - Arc::new(|args| make_scalar_function(math_expressions::factorial)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), BuiltinScalarFunction::Gcd => { - Arc::new(|args| make_scalar_function(math_expressions::gcd)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::gcd)(args)) } BuiltinScalarFunction::Isnan => { - Arc::new(|args| make_scalar_function(math_expressions::isnan)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::isnan)(args)) } BuiltinScalarFunction::Iszero => { - Arc::new(|args| make_scalar_function(math_expressions::iszero)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) } BuiltinScalarFunction::Lcm => { - Arc::new(|args| make_scalar_function(math_expressions::lcm)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::lcm)(args)) } BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), BuiltinScalarFunction::Nanvl => { - Arc::new(|args| make_scalar_function(math_expressions::nanvl)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } BuiltinScalarFunction::Radians => Arc::new(math_expressions::to_radians), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), BuiltinScalarFunction::Round => { - Arc::new(|args| make_scalar_function(math_expressions::round)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::round)(args)) } BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), @@ -309,135 +323,135 @@ pub fn create_physical_fun( BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh), BuiltinScalarFunction::Trunc => { - Arc::new(|args| make_scalar_function(math_expressions::trunc)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi), BuiltinScalarFunction::Power => { - Arc::new(|args| make_scalar_function(math_expressions::power)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) } BuiltinScalarFunction::Atan2 => { - Arc::new(|args| make_scalar_function(math_expressions::atan2)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args)) } BuiltinScalarFunction::Log => { - Arc::new(|args| make_scalar_function(math_expressions::log)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args)) } BuiltinScalarFunction::Cot => { - Arc::new(|args| make_scalar_function(math_expressions::cot)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } // array functions - BuiltinScalarFunction::ArrayAppend => { - Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) - } - BuiltinScalarFunction::ArraySort => { - Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) - } - BuiltinScalarFunction::ArrayConcat => { - Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) - } - BuiltinScalarFunction::ArrayEmpty => { - Arc::new(|args| make_scalar_function(array_expressions::array_empty)(args)) - } - BuiltinScalarFunction::ArrayHasAll => { - Arc::new(|args| make_scalar_function(array_expressions::array_has_all)(args)) - } - BuiltinScalarFunction::ArrayHasAny => { - Arc::new(|args| make_scalar_function(array_expressions::array_has_any)(args)) - } - BuiltinScalarFunction::ArrayHas => { - Arc::new(|args| make_scalar_function(array_expressions::array_has)(args)) - } - BuiltinScalarFunction::ArrayDims => { - Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) - } - BuiltinScalarFunction::ArrayDistinct => { - Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) - } - BuiltinScalarFunction::ArrayElement => { - Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) - } - BuiltinScalarFunction::ArrayExcept => { - Arc::new(|args| make_scalar_function(array_expressions::array_except)(args)) - } - BuiltinScalarFunction::ArrayLength => { - Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) - } + BuiltinScalarFunction::ArrayAppend => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_append)(args) + }), + BuiltinScalarFunction::ArraySort => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_sort)(args) + }), + BuiltinScalarFunction::ArrayConcat => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_concat)(args) + }), + BuiltinScalarFunction::ArrayEmpty => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_empty)(args) + }), + BuiltinScalarFunction::ArrayHasAll => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_has_all)(args) + }), + BuiltinScalarFunction::ArrayHasAny => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_has_any)(args) + }), + BuiltinScalarFunction::ArrayHas => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_has)(args) + }), + BuiltinScalarFunction::ArrayDims => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_dims)(args) + }), + BuiltinScalarFunction::ArrayDistinct => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_distinct)(args) + }), + BuiltinScalarFunction::ArrayElement => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_element)(args) + }), + BuiltinScalarFunction::ArrayExcept => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_except)(args) + }), + BuiltinScalarFunction::ArrayLength => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_length)(args) + }), BuiltinScalarFunction::Flatten => { - Arc::new(|args| make_scalar_function(array_expressions::flatten)(args)) - } - BuiltinScalarFunction::ArrayNdims => { - Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) + Arc::new(|args| make_scalar_function_inner(array_expressions::flatten)(args)) } + BuiltinScalarFunction::ArrayNdims => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_ndims)(args) + }), BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { - make_scalar_function(array_expressions::array_pop_front)(args) + make_scalar_function_inner(array_expressions::array_pop_front)(args) + }), + BuiltinScalarFunction::ArrayPopBack => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_pop_back)(args) + }), + BuiltinScalarFunction::ArrayPosition => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_position)(args) }), - BuiltinScalarFunction::ArrayPopBack => { - Arc::new(|args| make_scalar_function(array_expressions::array_pop_back)(args)) - } - BuiltinScalarFunction::ArrayPosition => { - Arc::new(|args| make_scalar_function(array_expressions::array_position)(args)) - } BuiltinScalarFunction::ArrayPositions => Arc::new(|args| { - make_scalar_function(array_expressions::array_positions)(args) + make_scalar_function_inner(array_expressions::array_positions)(args) + }), + BuiltinScalarFunction::ArrayPrepend => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_prepend)(args) + }), + BuiltinScalarFunction::ArrayRepeat => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_repeat)(args) + }), + BuiltinScalarFunction::ArrayRemove => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_remove)(args) + }), + BuiltinScalarFunction::ArrayRemoveN => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_remove_n)(args) }), - BuiltinScalarFunction::ArrayPrepend => { - Arc::new(|args| make_scalar_function(array_expressions::array_prepend)(args)) - } - BuiltinScalarFunction::ArrayRepeat => { - Arc::new(|args| make_scalar_function(array_expressions::array_repeat)(args)) - } - BuiltinScalarFunction::ArrayRemove => { - Arc::new(|args| make_scalar_function(array_expressions::array_remove)(args)) - } - BuiltinScalarFunction::ArrayRemoveN => { - Arc::new(|args| make_scalar_function(array_expressions::array_remove_n)(args)) - } BuiltinScalarFunction::ArrayRemoveAll => Arc::new(|args| { - make_scalar_function(array_expressions::array_remove_all)(args) + make_scalar_function_inner(array_expressions::array_remove_all)(args) + }), + BuiltinScalarFunction::ArrayReplace => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_replace)(args) }), - BuiltinScalarFunction::ArrayReplace => { - Arc::new(|args| make_scalar_function(array_expressions::array_replace)(args)) - } BuiltinScalarFunction::ArrayReplaceN => Arc::new(|args| { - make_scalar_function(array_expressions::array_replace_n)(args) + make_scalar_function_inner(array_expressions::array_replace_n)(args) }), BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { - make_scalar_function(array_expressions::array_replace_all)(args) + make_scalar_function_inner(array_expressions::array_replace_all)(args) + }), + BuiltinScalarFunction::ArraySlice => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_slice)(args) }), - BuiltinScalarFunction::ArraySlice => { - Arc::new(|args| make_scalar_function(array_expressions::array_slice)(args)) - } BuiltinScalarFunction::ArrayToString => Arc::new(|args| { - make_scalar_function(array_expressions::array_to_string)(args) + make_scalar_function_inner(array_expressions::array_to_string)(args) }), BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { - make_scalar_function(array_expressions::array_intersect)(args) + make_scalar_function_inner(array_expressions::array_intersect)(args) + }), + BuiltinScalarFunction::Range => Arc::new(|args| { + make_scalar_function_inner(array_expressions::gen_range)(args) + }), + BuiltinScalarFunction::Cardinality => Arc::new(|args| { + make_scalar_function_inner(array_expressions::cardinality)(args) + }), + BuiltinScalarFunction::ArrayResize => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_resize)(args) + }), + BuiltinScalarFunction::MakeArray => Arc::new(|args| { + make_scalar_function_inner(array_expressions::make_array)(args) + }), + BuiltinScalarFunction::ArrayUnion => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_union)(args) }), - BuiltinScalarFunction::Range => { - Arc::new(|args| make_scalar_function(array_expressions::gen_range)(args)) - } - BuiltinScalarFunction::Cardinality => { - Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) - } - BuiltinScalarFunction::ArrayResize => { - Arc::new(|args| make_scalar_function(array_expressions::array_resize)(args)) - } - BuiltinScalarFunction::MakeArray => { - Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) - } - BuiltinScalarFunction::ArrayUnion => { - Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) - } // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), // string functions BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::ascii::)(args) + make_scalar_function_inner(string_expressions::ascii::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ascii::)(args) + make_scalar_function_inner(string_expressions::ascii::)(args) } other => internal_err!("Unsupported data type {other:?} for function ascii"), }), @@ -455,10 +469,10 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function btrim"), }), @@ -470,7 +484,7 @@ pub fn create_physical_fun( Int32Type, "character_length" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -478,7 +492,7 @@ pub fn create_physical_fun( Int64Type, "character_length" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!( "Unsupported data type {other:?} for function character_length" @@ -486,13 +500,13 @@ pub fn create_physical_fun( }) } BuiltinScalarFunction::Chr => { - Arc::new(|args| make_scalar_function(string_expressions::chr)(args)) + Arc::new(|args| make_scalar_function_inner(string_expressions::chr)(args)) } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), - BuiltinScalarFunction::ConcatWithSeparator => { - Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args)) - } + BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { + make_scalar_function_inner(string_expressions::concat_ws)(args) + }), BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), BuiltinScalarFunction::DateBin => Arc::new(datetime_expressions::date_bin), @@ -534,10 +548,10 @@ pub fn create_physical_fun( } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::initcap::)(args) + make_scalar_function_inner(string_expressions::initcap::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::initcap::)(args) + make_scalar_function_inner(string_expressions::initcap::)(args) } other => { internal_err!("Unsupported data type {other:?} for function initcap") @@ -546,11 +560,11 @@ pub fn create_physical_fun( BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function left"), }), @@ -558,20 +572,20 @@ pub fn create_physical_fun( BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function lpad"), }), BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::ltrim::)(args) + make_scalar_function_inner(string_expressions::ltrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ltrim::)(args) + make_scalar_function_inner(string_expressions::ltrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function ltrim"), }), @@ -608,7 +622,7 @@ pub fn create_physical_fun( i32, "regexp_match" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_on_array_if_regex_expressions_feature_flag!( @@ -616,7 +630,7 @@ pub fn create_physical_fun( i64, "regexp_match" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!( "Unsupported data type {other:?} for function regexp_match" @@ -650,19 +664,19 @@ pub fn create_physical_fun( } BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::repeat::)(args) + make_scalar_function_inner(string_expressions::repeat::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::repeat::)(args) + make_scalar_function_inner(string_expressions::repeat::)(args) } other => internal_err!("Unsupported data type {other:?} for function repeat"), }), BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::replace::)(args) + make_scalar_function_inner(string_expressions::replace::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::replace::)(args) + make_scalar_function_inner(string_expressions::replace::)(args) } other => { internal_err!("Unsupported data type {other:?} for function replace") @@ -672,12 +686,12 @@ pub fn create_physical_fun( DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => { internal_err!("Unsupported data type {other:?} for function reverse") @@ -687,32 +701,32 @@ pub fn create_physical_fun( DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function right"), }), BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function rpad"), }), BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::rtrim::)(args) + make_scalar_function_inner(string_expressions::rtrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::rtrim::)(args) + make_scalar_function_inner(string_expressions::rtrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function rtrim"), }), @@ -730,10 +744,10 @@ pub fn create_physical_fun( } BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::split_part::)(args) + make_scalar_function_inner(string_expressions::split_part::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::split_part::)(args) + make_scalar_function_inner(string_expressions::split_part::)(args) } other => { internal_err!("Unsupported data type {other:?} for function split_part") @@ -741,12 +755,12 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::StringToArray => { Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(array_expressions::string_to_array::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(array_expressions::string_to_array::)(args) - } + DataType::Utf8 => make_scalar_function_inner( + array_expressions::string_to_array::, + )(args), + DataType::LargeUtf8 => make_scalar_function_inner( + array_expressions::string_to_array::, + )(args), other => { internal_err!( "Unsupported data type {other:?} for function string_to_array" @@ -756,10 +770,10 @@ pub fn create_physical_fun( } BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::starts_with::)(args) + make_scalar_function_inner(string_expressions::starts_with::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::starts_with::)(args) + make_scalar_function_inner(string_expressions::starts_with::)(args) } other => { internal_err!("Unsupported data type {other:?} for function starts_with") @@ -770,13 +784,13 @@ pub fn create_physical_fun( let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int32Type, "strpos" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int64Type, "strpos" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function strpos"), }), @@ -784,21 +798,21 @@ pub fn create_physical_fun( DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function substr"), }), BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { DataType::Int32 => { - make_scalar_function(string_expressions::to_hex::)(args) + make_scalar_function_inner(string_expressions::to_hex::)(args) } DataType::Int64 => { - make_scalar_function(string_expressions::to_hex::)(args) + make_scalar_function_inner(string_expressions::to_hex::)(args) } other => internal_err!("Unsupported data type {other:?} for function to_hex"), }), @@ -809,7 +823,7 @@ pub fn create_physical_fun( i32, "translate" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -817,7 +831,7 @@ pub fn create_physical_fun( i64, "translate" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => { internal_err!("Unsupported data type {other:?} for function translate") @@ -825,10 +839,10 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function trim"), }), @@ -849,10 +863,10 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::overlay::)(args) + make_scalar_function_inner(string_expressions::overlay::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::overlay::)(args) + make_scalar_function_inner(string_expressions::overlay::)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function overlay", @@ -860,12 +874,12 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Levenshtein => { Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::levenshtein::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::levenshtein::)(args) - } + DataType::Utf8 => make_scalar_function_inner( + string_expressions::levenshtein::, + )(args), + DataType::LargeUtf8 => make_scalar_function_inner( + string_expressions::levenshtein::, + )(args), other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function levenshtein", ))), @@ -879,7 +893,7 @@ pub fn create_physical_fun( i32, "substr_index" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -887,7 +901,7 @@ pub fn create_physical_fun( i64, "substr_index" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function substr_index", @@ -901,7 +915,7 @@ pub fn create_physical_fun( Int32Type, "find_in_set" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -909,7 +923,7 @@ pub fn create_physical_fun( Int64Type, "find_in_set" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function find_in_set", @@ -3108,7 +3122,7 @@ mod tests { #[test] fn test_make_scalar_function() -> Result<()> { - let adapter_func = make_scalar_function(dummy_function); + let adapter_func = make_scalar_function_inner(dummy_function); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); let array_arg = ColumnarValue::Array( diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index b778fd86c24b..bdd272563e75 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -36,7 +36,9 @@ use hashbrown::HashMap; use regex::Regex; use std::sync::{Arc, OnceLock}; -use crate::functions::{make_scalar_function, make_scalar_function_with_hints, Hint}; +use crate::functions::{ + make_scalar_function_inner, make_scalar_function_with_hints, Hint, +}; /// Get the first argument from the given string array. /// @@ -401,7 +403,7 @@ pub fn specialize_regexp_replace( // If there are no specialized implementations, we'll fall back to the // generic implementation. - (_, _, _, _) => Ok(make_scalar_function(regexp_replace::)), + (_, _, _, _) => Ok(make_scalar_function_inner(regexp_replace::)), } } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 9377501499e2..d9eda5d00d52 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -23,7 +23,6 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{ create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, @@ -117,7 +116,7 @@ impl Serializeable for Expr { vec![], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, - make_scalar_function(|_| unimplemented!()), + Arc::new(|_| unimplemented!()), ))) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d38cfd400ad..cf991e524f27 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -35,7 +35,6 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; @@ -54,9 +53,9 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, - Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + ColumnarValue, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, + Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1632,9 +1631,12 @@ fn roundtrip_aggregate_udf() { #[test] fn roundtrip_scalar_udf() { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); + let scalar_fn = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!("should be array") + }; + Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) + }); let udf = create_udf( "dummy", diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 8e0f75ce7d11..4f91713f488f 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -47,7 +47,6 @@ use datafusion::physical_plan::expressions::{ GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -73,8 +72,8 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF, - WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, + SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; @@ -569,9 +568,12 @@ fn roundtrip_scalar_udf() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); + let scalar_fn = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!("should be array") + }; + Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) + }); let udf = create_udf( "dummy", diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 222d1a3a629c..7dd0333909ee 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -21,9 +21,8 @@ use arrow::array::ArrayRef; use arrow::datatypes::DataType; use datafusion::execution::FunctionRegistry; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::SessionContext; -use datafusion_expr::{col, create_udf, lit}; +use datafusion_expr::{col, create_udf, lit, ColumnarValue}; use datafusion_expr::{Expr, Volatility}; use datafusion_proto::bytes::Serializeable; @@ -226,9 +225,12 @@ fn roundtrip_deeply_nested() { /// return a `SessionContext` with a `dummy` function registered as a UDF fn context_with_udf() -> SessionContext { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); + let scalar_fn = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!("should be array") + }; + Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) + }); let udf = create_udf( "dummy", diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index a5ce7ccb9fe0..889ccdcd66d4 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -28,8 +28,7 @@ use arrow::array::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionState; -use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility}; -use datafusion::physical_expr::functions::make_scalar_function; +use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ @@ -356,9 +355,16 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { /// Create a UDF function named "example". See the `sample_udf.rs` example /// file for an explanation of the API. fn create_example_udf() -> ScalarUDF { - let adder = make_scalar_function(|args: &[ArrayRef]| { - let lhs = as_float64_array(&args[0]).expect("cast failed"); - let rhs = as_float64_array(&args[1]).expect("cast failed"); + let adder = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(lhs) = &args[0] else { + panic!("should be array") + }; + let ColumnarValue::Array(rhs) = &args[1] else { + panic!("should be array") + }; + + let lhs = as_float64_array(lhs).expect("cast failed"); + let rhs = as_float64_array(rhs).expect("cast failed"); let array = lhs .iter() .zip(rhs.iter()) @@ -367,7 +373,7 @@ fn create_example_udf() -> ScalarUDF { _ => None, }) .collect::(); - Ok(Arc::new(array) as ArrayRef) + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) }); create_udf( "example",