Skip to content

Commit

Permalink
Support List type coercion for CASE-WHEN-THEN expression (#12490)
Browse files Browse the repository at this point in the history
* support list type coercion

* add planing and sql tests

* clippy

* support to compare nested type for case-when expression

* simplify the macro rules

* fix the FixedSizeList type coercion and add tests

* add test for THEN-ELSE
  • Loading branch information
goldmedal committed Sep 21, 2024
1 parent 515a64e commit 244ce5a
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 2 deletions.
73 changes: 73 additions & 0 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,22 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(List(_), List(_)) => Some(lhs_type.clone()),
(LargeList(_), List(_)) => Some(lhs_type.clone()),
(List(_), LargeList(_)) => Some(rhs_type.clone()),
(LargeList(_), LargeList(_)) => Some(lhs_type.clone()),
(List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()),
(FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()),
// Coerce to the left side FixedSizeList type if the list lengths are the same,
// otherwise coerce to list with the left type for dynamic length
(FixedSizeList(lf, ls), FixedSizeList(_, rs)) => {
if ls == rs {
Some(lhs_type.clone())
} else {
Some(List(Arc::clone(lf)))
}
}
(LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()),
(FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()),
_ => None,
}
}
Expand Down Expand Up @@ -1906,6 +1922,63 @@ mod tests {
DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into()))
);

// list
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
test_coercion_binary_rule!(
DataType::List(Arc::clone(&inner_field)),
DataType::List(Arc::clone(&inner_field)),
Operator::Eq,
DataType::List(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::List(Arc::clone(&inner_field)),
DataType::LargeList(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::LargeList(Arc::clone(&inner_field)),
DataType::List(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::LargeList(Arc::clone(&inner_field)),
DataType::LargeList(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
Operator::Eq,
DataType::FixedSizeList(Arc::clone(&inner_field), 10)
);
test_coercion_binary_rule!(
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
DataType::LargeList(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::LargeList(Arc::clone(&inner_field)),
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::List(Arc::clone(&inner_field)),
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
Operator::Eq,
DataType::List(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
DataType::List(Arc::clone(&inner_field)),
Operator::Eq,
DataType::List(Arc::clone(&inner_field))
);

// TODO add other data type
Ok(())
}
Expand Down
180 changes: 180 additions & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,186 @@ mod test {
Ok(())
}

macro_rules! test_case_expression {
($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
let case = Case {
expr: $expr.map(|e| Box::new(col(e))),
when_then_expr: $when_then,
else_expr: None,
};

let expected =
cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);

let actual = coerce_case_expression(case, &$schema)?;
assert_eq!(expected, actual);
};
}

#[test]
fn tes_case_when_list() -> Result<()> {
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
let schema = Arc::new(DFSchema::from_unqualified_fields(
vec![
Field::new(
"large_list",
DataType::LargeList(Arc::clone(&inner_field)),
true,
),
Field::new(
"fixed_list",
DataType::FixedSizeList(Arc::clone(&inner_field), 3),
true,
),
Field::new("list", DataType::List(inner_field), true),
]
.into(),
std::collections::HashMap::new(),
)?);

test_case_expression!(
Some("list"),
vec![(Box::new(col("large_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("large_list"),
vec![(Box::new(col("list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("list"),
vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("fixed_list"),
vec![(Box::new(col("list")), Box::new(lit("1")))],
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("fixed_list"),
vec![(Box::new(col("large_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("large_list"),
vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);
Ok(())
}

#[test]
fn test_then_else_list() -> Result<()> {
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
let schema = Arc::new(DFSchema::from_unqualified_fields(
vec![
Field::new("boolean", DataType::Boolean, true),
Field::new(
"large_list",
DataType::LargeList(Arc::clone(&inner_field)),
true,
),
Field::new(
"fixed_list",
DataType::FixedSizeList(Arc::clone(&inner_field), 3),
true,
),
Field::new("list", DataType::List(inner_field), true),
]
.into(),
std::collections::HashMap::new(),
)?);

// large list and list
test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("large_list"))),
(Box::new(col("boolean")), Box::new(col("list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("list"))),
(Box::new(col("boolean")), Box::new(col("large_list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

// fixed list and list
test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("fixed_list"))),
(Box::new(col("boolean")), Box::new(col("list")))
],
DataType::Boolean,
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("list"))),
(Box::new(col("boolean")), Box::new(col("fixed_list")))
],
DataType::Boolean,
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

// fixed list and large list
test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("fixed_list"))),
(Box::new(col("boolean")), Box::new(col("large_list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("large_list"))),
(Box::new(col("boolean")), Box::new(col("fixed_list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);
Ok(())
}

#[test]
fn interval_plus_timestamp() -> Result<()> {
// SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
Expand Down
10 changes: 8 additions & 2 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
use arrow::compute::kernels::cmp::eq;
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
use arrow::datatypes::{DataType, Schema};
Expand All @@ -33,6 +32,7 @@ use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarV
use datafusion_expr::ColumnarValue;

use super::{Column, Literal};
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
Expand Down Expand Up @@ -204,7 +204,13 @@ impl CaseExpr {
.evaluate_selection(batch, &remainder)?;
let when_value = when_value.into_array(batch.num_rows())?;
// build boolean array representing which rows match the "when" value
let when_match = eq(&when_value, &base_value)?;
let when_match = compare_with_eq(
&when_value,
&base_value,
// The types of case and when expressions will be coerced to match.
// We only need to check if the base_value is nested.
base_value.data_type().is_nested(),
)?;
// Treat nulls as false
let when_match = match when_match.null_count() {
0 => Cow::Borrowed(&when_match),
Expand Down
Loading

0 comments on commit 244ce5a

Please sign in to comment.