Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick is not null pushdown #268

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions datafusion/optimizer/src/filter_null_join_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
//! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable

use crate::optimizer::ApplyOrder;
use crate::push_down_filter::on_lr_is_preserved;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Result};
use datafusion_expr::utils::conjunction;
use datafusion_expr::{
logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan,
};
use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan};
use std::sync::Arc;

/// The FilterNullJoinKeys rule will identify inner joins with equi-join conditions
/// where the join key is nullable on one side and non-nullable on the other side
/// and then insert an `IsNotNull` filter on the nullable side since null values
/// The FilterNullJoinKeys rule will identify joins with equi-join conditions
/// where the join key is nullable and then insert an `IsNotNull` filter on the nullable side since null values
/// can never match.
#[derive(Default)]
pub struct FilterNullJoinKeys {}
Expand Down Expand Up @@ -59,21 +57,23 @@ impl OptimizerRule for FilterNullJoinKeys {
if !config.options().optimizer.filter_null_join_keys {
return Ok(Transformed::no(plan));
}

match plan {
LogicalPlan::Join(mut join) if join.join_type == JoinType::Inner => {
LogicalPlan::Join(mut join) if !join.on.is_empty() => {
let (left_preserved, right_preserved) =
on_lr_is_preserved(join.join_type);

let left_schema = join.left.schema();
let right_schema = join.right.schema();

let mut left_filters = vec![];
let mut right_filters = vec![];

for (l, r) in &join.on {
if l.nullable(left_schema)? {
if left_preserved && l.nullable(left_schema)? {
left_filters.push(l.clone());
}

if r.nullable(right_schema)? {
if right_preserved && r.nullable(right_schema)? {
right_filters.push(r.clone());
}
}
Expand Down Expand Up @@ -117,7 +117,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Column;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, lit, LogicalPlanBuilder};
use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder};

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected)
Expand All @@ -126,18 +126,41 @@ mod tests {
#[test]
fn left_nullable() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
let expected = "Inner Join: t1.optional_id = t2.id\
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn left_nullable_left_join() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?;
let expected = "Left Join: t1.optional_id = t2.id\
\n TableScan: t1\
\n TableScan: t2";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn left_nullable_left_join_reordered() -> Result<()> {
let (t_left, t_right) = test_tables()?;
// Note: order of tables is reversed
let plan =
build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?;
let expected = "Left Join: t2.id = t1.optional_id\
\n TableScan: t2\
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn left_nullable_on_condition_reversed() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id")?;
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?;
let expected = "Inner Join: t1.optional_id = t2.id\
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
Expand All @@ -148,7 +171,7 @@ mod tests {
#[test]
fn nested_join_multiple_filter_expr() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
let schema = Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("t1_id", DataType::UInt32, true),
Expand Down Expand Up @@ -252,11 +275,12 @@ mod tests {
right_table: LogicalPlan,
left_key: &str,
right_key: &str,
join_type: JoinType,
) -> Result<LogicalPlan> {
LogicalPlanBuilder::from(left_table)
.join(
right_table,
JoinType::Inner,
join_type,
(
vec![Column::from_qualified_name(left_key)],
vec![Column::from_qualified_name(right_key)],
Expand Down
55 changes: 30 additions & 25 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,38 +146,43 @@ pub struct PushDownFilter {}
/// there may be rows in the output that don't directly map to a row in the
/// right input (due to nulls filling where there is no match on the right).
///
/// This is important because we can always push down post-join filters to a preserved
/// side of the join, assuming the filter only references columns from that side. For the
/// non-preserved side it can be more tricky.
///
/// Returns a tuple of booleans - (left_preserved, right_preserved).
fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
/// - In a left join, the left side is preserved (we can push predicates) but
/// the right is not, because there may be rows in the output that don't
/// directly map to a row in the right input (due to nulls filling where there
/// is no match on the right).
pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
match join_type {
JoinType::Inner => Ok((true, true)),
JoinType::Left => Ok((true, false)),
JoinType::Right => Ok((false, true)),
JoinType::Full => Ok((false, false)),
JoinType::Inner => (true, true),
JoinType::Left => (true, false),
JoinType::Right => (false, true),
JoinType::Full => (false, false),
// No columns from the right side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
// No columns from the left side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
JoinType::RightSemi | JoinType::RightAnti => (false, true),
}
}

/// For a given JOIN logical plan, determine whether each side of the join is preserved
/// in terms on join filtering.
/// Predicates from join filter can only be pushed to preserved join side.
fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
/// For a given JOIN type, determine whether each input of the join is preserved
/// for the join condition (`ON` clause filters).
///
/// It is only correct to push filters below a join for preserved inputs.
///
/// # Return Value
/// A tuple of booleans - (left_preserved, right_preserved).
///
/// See [`lr_is_preserved`] for a definition of "preserved".
pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
match join_type {
JoinType::Inner => Ok((true, true)),
JoinType::Left => Ok((false, true)),
JoinType::Right => Ok((true, false)),
JoinType::Full => Ok((false, false)),
JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
JoinType::LeftAnti => Ok((false, true)),
JoinType::RightAnti => Ok((true, false)),
JoinType::Inner => (true, true),
JoinType::Left => (false, true),
JoinType::Right => (true, false),
JoinType::Full => (false, false),
JoinType::LeftSemi | JoinType::RightSemi => (true, true),
JoinType::LeftAnti => (false, true),
JoinType::RightAnti => (true, false),
}
}

Expand Down Expand Up @@ -395,7 +400,7 @@ fn push_down_all_join(
) -> Result<Transformed<LogicalPlan>> {
let is_inner_join = join.join_type == JoinType::Inner;
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?;
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

// The predicates can be divided to three categories:
// 1) can push through join to its children(left or right)
Expand Down Expand Up @@ -435,7 +440,7 @@ fn push_down_all_join(
}

if !on_filter.is_empty() {
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?;
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
for on in on_filter {
if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? {
left_push.push(on)
Expand Down
28 changes: 19 additions & 9 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ fn semi_join_with_join_filter() -> Result<()> {
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8\
\n LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
\n SubqueryAlias: __correlated_sq_1\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32, col_uint32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand All @@ -141,7 +143,8 @@ fn anti_join_with_join_filter() -> Result<()> {
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
\n SubqueryAlias: __correlated_sq_1\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32, col_uint32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand All @@ -152,11 +155,13 @@ fn where_exists_distinct() -> Result<()> {
SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
let plan = test_sql(sql)?;
let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32\
\n TableScan: test projection=[col_int32]\
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]\
\n SubqueryAlias: __correlated_sq_1\
\n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand All @@ -172,9 +177,12 @@ fn intersect() -> Result<()> {
\n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\
\n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\
\n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_utf8]\
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_utf8]\
\n TableScan: test projection=[col_int32, col_utf8]\
\n TableScan: test projection=[col_int32, col_utf8]";
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand Down Expand Up @@ -270,9 +278,11 @@ fn test_same_name_but_not_ambiguous() {
let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\
\n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\
\n SubqueryAlias: t1\
\n TableScan: test projection=[col_int32]\
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
}

Expand Down
Loading