diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 0dde7604cce2..d7e839824b3b 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -60,9 +60,8 @@ fn test_eq_with_coercion() { #[test] fn test_get_field() { - // field access Expr::field() requires a rewrite to work evaluate_expr_test( - col("props").field("a"), + get_field(col("props"), lit("a")), vec![ "+------------+", "| expr |", @@ -77,11 +76,8 @@ fn test_get_field() { #[test] fn test_nested_get_field() { - // field access Expr::field() requires a rewrite to work, test when it is - // not the root expression evaluate_expr_test( - col("props") - .field("a") + get_field(col("props"), lit("a")) .eq(lit("2021-02-02")) .or(col("id").eq(lit(1))), vec![ @@ -98,9 +94,8 @@ fn test_nested_get_field() { #[test] fn test_list() { - // list access also requires a rewrite to work evaluate_expr_test( - col("list").index(lit(1i64)), + array_element(col("list"), lit(1i64)), vec![ "+------+", "| expr |", "+------+", "| one |", "| two |", "| five |", "+------+", @@ -110,9 +105,8 @@ fn test_list() { #[test] fn test_list_range() { - // range access also requires a rewrite to work evaluate_expr_test( - col("list").range(lit(1i64), lit(2i64)), + array_slice(col("list"), lit(1i64), lit(2i64), None), vec![ "+--------------+", "| expr |", diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 5280355a8224..a7aba78c1dbe 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -19,7 +19,6 @@ use crate::array_has::array_has_all; use crate::concat::{array_append, array_concat, array_prepend}; -use crate::extract::{array_element, array_slice}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::list_ndims; @@ -27,8 +26,7 @@ use datafusion_common::Result; use datafusion_common::{Column, 
DFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Operator}; -use datafusion_functions::expr_fn::get_field; +use datafusion_expr::{BinaryExpr, Expr, Operator}; /// Rewrites expressions into function calls to array functions pub(crate) struct ArrayFunctionRewriter {} @@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter { Transformed::yes(array_prepend(*left, *right)) } - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::NamedStructField { name }, - }) => { - let name = Expr::Literal(name); - Transformed::yes(get_field(*expr, name)) - } - - // expr[idx] ==> array_element(expr, idx) - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::ListIndex { key }, - }) => Transformed::yes(array_element(*expr, *key)), - - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - Expr::GetIndexedField(GetIndexedField { - expr, - field: - GetFieldAccess::ListRange { - start, - stop, - stride, - }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))), - _ => Transformed::no(expr), }; Ok(transformed) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 713ad6f72c24..d297b2e4df5b 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, - TableReference, + ScalarValue, TableReference, }; -use datafusion_expr::{Case, Expr}; +use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } let nested_name = 
nested_names[0].to_string(); - Ok(Expr::Column(Column::from((qualifier, field))).field(nested_name)) + + let col = Expr::Column(Column::from((qualifier, field))); + if let Some(udf) = + self.context_provider.get_function_meta("get_field") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![col, lit(ScalarValue::from(nested_name))], + ))) + } else { + internal_err!("get_field not found") + } } // found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index ed5421edfbb0..6445c3f7a885 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,7 +29,7 @@ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast, + GetFieldAccess, Like, Literal, Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr }; - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - self.plan_indices(indices, schema, planner_context)?, - ))) + let field = self.plan_indices(indices, schema, planner_context)?; + match field { + GetFieldAccess::NamedStructField { name } => { + if let Some(udf) = self.context_provider.get_function_meta("get_field") { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, lit(name)], + ))) + } else { + internal_err!("get_field not found") + } + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key } => { + if let Some(udf) = + self.context_provider.get_function_meta("array_element") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, *key], + ))) + } else { + internal_err!("array_element not found") + } + } + // expr[start, stop, stride]
==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => { + if let Some(udf) = self.context_provider.get_function_meta("array_slice") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, *start, *stop, *stride], + ))) + } else { + internal_err!("array_slice not found") + } + } + } } } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 4b5f4d770a03..2dc00cbc5001 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2324,28 +2324,134 @@ host3 3.3 # can have an aggregate function with an inner CASE WHEN query TR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 101 host2 202 host3 303 +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; + # can have 2 projections with aggr(short_circuited), with different short-circuited expr query TRR -select 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(coalesce(server_load1)), + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 -# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
CASE WHEN) +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(coalesce(server['c1'])), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server, + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; + query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then server_load1 + end + )), + sum(( + case when server_host is not null + then server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c1'] + end + )), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; + # can have 2 projections with aggr(short_circuited), with the same 
short-circuited expr (e.g. coalesce) query TRR select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;