Skip to content

Commit

Permalink
Remove ScalarFunctionDefinition (#10325)
Browse files (browse the repository at this point in the history)
* Remove ScalarFunctionDefinition

* Fix test

* Rename `func_def` to `func`

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
lewiszlw and alamb committed May 7, 2024
1 parent 5146f44 commit 742e3c5
Show file tree
Hide file tree
Showing 22 changed files with 124 additions and 231 deletions.
18 changes: 7 additions & 11 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use log::{debug, trace};

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DFSchema, DataFusionError};
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_expr::{Expr, Volatility};
use datafusion_physical_expr::create_physical_expr;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
Expand Down Expand Up @@ -89,16 +89,12 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),

Expr::ScalarFunction(scalar_function) => {
match &scalar_function.func_def {
ScalarFunctionDefinition::UDF(fun) => {
match fun.signature().volatility {
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(TreeNodeRecursion::Stop)
}
}
match scalar_function.func.signature().volatility {
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(TreeNodeRecursion::Stop)
}
}
}
Expand Down
19 changes: 5 additions & 14 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1301,8 +1301,7 @@ mod tests {
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{
ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, CastExpr, NegativeExpr,
Expand Down Expand Up @@ -1363,9 +1362,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -1431,9 +1428,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -1502,9 +1497,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -1570,9 +1563,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b_new", 1)),
Expand Down
54 changes: 9 additions & 45 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ use std::sync::Arc;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::expr_to_columns;
use crate::window_frame;
use crate::{
aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
Signature,
};
use crate::{window_frame, Volatility};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
Expand Down Expand Up @@ -399,60 +399,26 @@ impl Between {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Defines which implementation of a function for DataFusion to call.
pub enum ScalarFunctionDefinition {
/// Resolved to a user defined function
UDF(Arc<crate::ScalarUDF>),
}

/// ScalarFunction expression invokes a built-in scalar function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
/// The function
pub func_def: ScalarFunctionDefinition,
pub func: Arc<crate::ScalarUDF>,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
}

impl ScalarFunction {
// return the Function's name
pub fn name(&self) -> &str {
self.func_def.name()
}
}

impl ScalarFunctionDefinition {
/// Function's name for display
pub fn name(&self) -> &str {
match self {
ScalarFunctionDefinition::UDF(udf) => udf.name(),
}
}

/// Whether this function is volatile, i.e. whether it can return different results
/// when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
match self {
ScalarFunctionDefinition::UDF(udf) => {
Ok(udf.signature().volatility == crate::Volatility::Volatile)
}
}
self.func.name()
}
}

impl ScalarFunction {
/// Create a new ScalarFunction expression with a user-defined function (UDF)
pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
Self {
func_def: ScalarFunctionDefinition::UDF(udf),
args,
}
}

/// Create a new ScalarFunction expression with a user-defined function (UDF)
pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec<Expr>) -> Self {
Self { func_def, args }
Self { func: udf, args }
}
}

Expand Down Expand Up @@ -1299,7 +1265,7 @@ impl Expr {
/// results when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
self.exists(|expr| {
Ok(matches!(expr, Expr::ScalarFunction(func) if func.func_def.is_volatile()?))
Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile ))
})
}

Expand Down Expand Up @@ -1334,9 +1300,7 @@ impl Expr {
/// and thus any side effects (like divide by zero) may not be encountered
pub fn short_circuits(&self) -> bool {
match self {
Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
matches!(func_def, ScalarFunctionDefinition::UDF(fun) if fun.short_circuits())
}
Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(),
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
Expand Down Expand Up @@ -2071,7 +2035,7 @@ mod test {
}

#[test]
fn test_is_volatile_scalar_func_definition() {
fn test_is_volatile_scalar_func() {
// UDF
#[derive(Debug)]
struct TestScalarUDF {
Expand Down Expand Up @@ -2100,7 +2064,7 @@ mod test {
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
}));
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
assert_ne!(udf.signature().volatility, Volatility::Volatile);

let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(
Expand All @@ -2109,7 +2073,7 @@ mod test {
Volatility::Volatile,
),
}));
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
assert_eq!(udf.signature().volatility, Volatility::Volatile);
}

use super::*;
Expand Down
16 changes: 6 additions & 10 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
Sort, TryCast, Unnest, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
Expand Down Expand Up @@ -133,30 +133,26 @@ impl ExprSchemable for Expr {
}
}
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::UDF(fun) => {
// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types(&arg_data_types, fun.signature()).map_err(|_| {
data_types(&arg_data_types, func.signature()).map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
func.name(),
func.signature().clone(),
&arg_data_types,
)
)
})?;

// perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
}
}
Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
pub use columnar_value::ColumnarValue;
pub use expr::{
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition,
Like, TryCast, WindowFunctionDefinition,
};
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
Expand Down
12 changes: 6 additions & 6 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
};
use crate::{Expr, GetFieldAccess};

Expand Down Expand Up @@ -281,11 +281,11 @@ impl TreeNode for Expr {
nulls_first,
}) => transform_box(expr, &mut f)?
.update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))),
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
transform_vec(args, &mut f)?.map_data(|new_args| match func_def {
ScalarFunctionDefinition::UDF(fun) => {
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args)))
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
transform_vec(args, &mut f)?.map_data(|new_args| {
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
func, new_args,
)))
})?
}
Expr::WindowFunction(WindowFunction {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/functions-array/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,20 @@ impl FunctionRewrite for ArrayFunctionRewriter {
/// Returns true if expr is a function call to the specified named function.
/// Returns false otherwise.
fn is_func(expr: &Expr, func_name: &str) -> bool {
let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
return false;
};

func_def.name() == func_name
func.name() == func_name
}

/// Returns true if expr is a function call with one of the specified names
fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
return false;
};

func_names.contains(&func_def.name())
func_names.contains(&func.name())
}

/// returns Some(col) if this is Expr::Column
Expand Down
19 changes: 5 additions & 14 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
};
use datafusion_expr::{lit, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};

use arrow::array::{ArrayRef, Float32Array, Float64Array};
use datafusion_expr::TypeSignature::*;
Expand Down Expand Up @@ -178,8 +176,8 @@ impl ScalarUDFImpl for LogFunc {
&info.get_data_type(&base)?,
)?)))
}
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
Expr::ScalarFunction(ScalarFunction { func, mut args })
if is_pow(&func) && args.len() == 2 && base == args[0] =>
{
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
Expand Down Expand Up @@ -207,15 +205,8 @@ impl ScalarUDFImpl for LogFunc {
}

/// Returns true if the function is `PowerFunc`
fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
match func_def {
ScalarFunctionDefinition::UDF(fun) => fun
.as_ref()
.inner()
.as_any()
.downcast_ref::<PowerFunc>()
.is_some(),
}
fn is_pow(func: &ScalarUDF) -> bool {
func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
}

#[cfg(test)]
Expand Down
17 changes: 5 additions & 12 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};

use arrow::array::{ArrayRef, Float64Array, Int64Array};
use datafusion_expr::TypeSignature::*;
Expand Down Expand Up @@ -140,8 +140,8 @@ impl ScalarUDFImpl for PowerFunc {
Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => {
Ok(ExprSimplifyResult::Simplified(base))
}
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
if is_log(&func_def) && args.len() == 2 && base == args[0] =>
Expr::ScalarFunction(ScalarFunction { func, mut args })
if is_log(&func) && args.len() == 2 && base == args[0] =>
{
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
Expand All @@ -152,15 +152,8 @@ impl ScalarUDFImpl for PowerFunc {
}

/// Return true if this function call is a call to `Log`
fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
match func_def {
ScalarFunctionDefinition::UDF(fun) => fun
.as_ref()
.inner()
.as_any()
.downcast_ref::<LogFunc>()
.is_some(),
}
fn is_log(func: &ScalarUDF) -> bool {
func.inner().as_any().downcast_ref::<LogFunc>().is_some()
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 742e3c5

Please sign in to comment.