From 2b0542f530c95a3fb11b075695adf309f11b2695 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 22 Jul 2024 20:35:42 +0800 Subject: [PATCH] fix: CASE with NULL (#11542) * fix: CASE with NULL * chore: Add tests * chore * chore: Fix CI * chore: Support all types are NULL * chore: Fix CI * chore: add more tests * fix: Return first non-null type in then exprs * chore: Fix CI * Update datafusion/expr/src/expr_schema.rs Co-authored-by: Jonah Gao * Update datafusion/expr/src/expr_schema.rs Co-authored-by: Jonah Gao --------- Co-authored-by: Jonah Gao --- datafusion/expr/src/expr_schema.rs | 12 +++++++- .../sqllogictest/test_files/aggregate.slt | 28 +++++++++++++++++++ datafusion/sqllogictest/test_files/scalar.slt | 8 +++--- datafusion/sqllogictest/test_files/select.slt | 27 ++++++++++++++++++ 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1df5d6c4d736..5e0571f712ee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -112,7 +112,17 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.data_type()), - Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), + Expr::Case(case) => { + for (_, then_expr) in &case.when_then_expr { + let then_type = then_expr.get_type(schema)?; + if !then_type.is_null() { + return Ok(then_type); + } + } + case.else_expr + .as_ref() + .map_or(Ok(DataType::Null), |e| e.get_type(schema)) + } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::Unnest(Unnest { expr }) => { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index d0f7f2d9ac7a..bb5ce1150a58 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5418,6 +5418,34 @@ SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t; statement ok DROP TABLE t; +# Test for CASE with NULL in aggregate function +statement ok +CREATE TABLE example(data double precision); + +statement ok +INSERT INTO example VALUES (1), (2), (NULL), (4); + +query RR +SELECT + sum(CASE WHEN data is NULL THEN NULL ELSE data+1 END) as then_null, + sum(CASE WHEN data is NULL THEN data+1 ELSE NULL END) as else_null +FROM example; +---- +10 NULL + +query R +SELECT + CASE data WHEN 1 THEN NULL WHEN 2 THEN 3.3 ELSE NULL END as case_null +FROM example; +---- +NULL +3.3 +NULL +NULL + +statement ok +drop table example; + # Test Convert FirstLast optimizer rule statement ok CREATE EXTERNAL TABLE convert_first_last_table ( diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 48f94fc080a4..ff9afa94f40a 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1238,27 +1238,27 @@ SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END bar # case_expr_with_null() -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,null),(2,3)) as t (a,b)) a; ---- NULL 3 -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a; ---- 1 3 # case_expr_with_nulls() -query ? +query I select case when b is null then null when b < 3 then null when b >=3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a ---- NULL NULL 4 -query ? +query I select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a; ---- NULL diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 03426dec874f..6884efc07e15 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -613,6 +613,33 @@ END; ---- 2 +# select case when type is null +query I +select CASE + WHEN NULL THEN 1 + ELSE 2 +END; +---- +2 + +# select case then type is null +query I +select CASE + WHEN 10 > 5 THEN NULL + ELSE 2 +END; +---- +NULL + +# select case else type is null +query I +select CASE + WHEN 10 = 5 THEN 1 + ELSE NULL +END; +---- +NULL + # Binary Expression for LargeUtf8 # issue: https://github.com/apache/datafusion/issues/5893 statement ok