From e473953707923103901da91b34ec72a0d9253345 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 30 Sep 2023 00:35:10 +0800 Subject: [PATCH] minor: revert parsing precedence between Aggr and UDAF (#7682) * minor: revert parsing precedence between Aggr and UDAF Signed-off-by: Ruihang Xia * add unit test Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- .../user_defined/user_defined_aggregates.rs | 37 +++++++++++++++++-- datafusion/sql/src/expr/function.rs | 18 ++++----- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 64547bbdfa36..3b7b4d0e87b7 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -169,6 +169,37 @@ async fn test_udaf_returning_struct_subquery() { assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); } +#[tokio::test] +async fn test_udaf_shadows_builtin_fn() { + let TestContext { + mut ctx, + test_state, + } = TestContext::new(); + let sql = "SELECT sum(arrow_cast(time, 'Int64')) from t"; + + // compute with builtin `sum` aggregator + let expected = [ + "+-------------+", + "| SUM(t.time) |", + "+-------------+", + "| 19000 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); + + // Register `TimeSum` with name `sum`. This will shadow the builtin one + let sql = "SELECT sum(time) from t"; + TimeSum::register(&mut ctx, test_state.clone(), "sum"); + let expected = [ + "+----------------------------+", + "| sum(t.time) |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000019 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } @@ -214,7 +245,7 @@ impl TestContext { // Tell DataFusion about the "first" function FirstSelector::register(&mut ctx); // Tell DataFusion about the "time_sum" function - TimeSum::register(&mut ctx, Arc::clone(&test_state)); + TimeSum::register(&mut ctx, Arc::clone(&test_state), "time_sum"); Self { ctx, test_state } } @@ -281,7 +312,7 @@ impl TimeSum { Self { sum: 0, test_state } } - fn register(ctx: &mut SessionContext, test_state: Arc) { + fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); // Returns the same type as its input @@ -301,8 +332,6 @@ impl TimeSum { let accumulator: AccumulatorFactoryFunction = Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); - let name = "time_sum"; - let time_sum = AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 05f80fcfafa9..3861b4848d9b 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -124,6 +124,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(expr); } } else { + // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function + if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; + return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + fm, args, None, None, + ))); + } + // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let distinct = function.distinct; @@ -141,15 +150,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))); }; - // User defined aggregate functions (UDAF) - if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, - ))); - } - // Special case arrow_cast (as its type is dependent on its argument value) if name == ARROW_CAST_NAME { let args =