Skip to content

Commit

Permalink
refactor: Change join_where semantics (#18640)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 10, 2024
1 parent 76a340b commit 45c8e96
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 70 deletions.
23 changes: 23 additions & 0 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,29 @@ impl Operator {
)
}

pub fn swap_operands(self) -> Self {
match self {
Operator::Eq => Operator::Eq,
Operator::Gt => Operator::Lt,
Operator::GtEq => Operator::LtEq,
Operator::LtEq => Operator::GtEq,
Operator::Or => Operator::Or,
Operator::LogicalAnd => Operator::LogicalAnd,
Operator::LogicalOr => Operator::LogicalOr,
Operator::Xor => Operator::Xor,
Operator::NotEq => Operator::NotEq,
Operator::EqValidity => Operator::EqValidity,
Operator::NotEqValidity => Operator::NotEqValidity,
Operator::Divide => Operator::Multiply,
Operator::Multiply => Operator::Divide,
Operator::And => Operator::And,
Operator::Plus => Operator::Minus,
Operator::Minus => Operator::Plus,
Operator::Lt => Operator::Gt,
_ => unimplemented!(),
}
}

pub fn is_arithmetic(&self) -> bool {
!(self.is_comparison())
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
options,
} => {
return join::resolve_join(
input_left,
input_right,
Either::Left(input_left),
Either::Left(input_right),
left_on,
right_on,
predicates,
Expand Down
190 changes: 155 additions & 35 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow::legacy::error::PolarsResult;
use either::Either;

use super::*;
use crate::dsl::Expr;
Expand All @@ -16,8 +17,8 @@ fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {
Ok(())
}
pub fn resolve_join(
input_left: Arc<DslPlan>,
input_right: Arc<DslPlan>,
input_left: Either<Arc<DslPlan>, Node>,
input_right: Either<Arc<DslPlan>, Node>,
left_on: Vec<Expr>,
right_on: Vec<Expr>,
predicates: Vec<Expr>,
Expand All @@ -26,7 +27,13 @@ pub fn resolve_join(
) -> PolarsResult<Node> {
if !predicates.is_empty() {
debug_assert!(left_on.is_empty() && right_on.is_empty());
return resolve_join_where(input_left, input_right, predicates, options, ctxt);
return resolve_join_where(
input_left.unwrap_left(),
input_right.unwrap_left(),
predicates,
options,
ctxt,
);
}

let owned = Arc::unwrap_or_clone;
Expand Down Expand Up @@ -62,10 +69,12 @@ pub fn resolve_join(
);
}

let input_left =
to_alp_impl(owned(input_left), ctxt).map_err(|e| e.context(failed_input!(join left)))?;
let input_right =
to_alp_impl(owned(input_right), ctxt).map_err(|e| e.context(failed_input!(join, right)))?;
let input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(join left)))
})?;
let input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(join right)))
})?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
Expand Down Expand Up @@ -129,7 +138,6 @@ fn resolve_join_where(
ctxt: &mut DslConversionContext,
) -> PolarsResult<Node> {
check_join_keys(&predicates)?;

for e in &predicates {
let no_binary_comparisons = e
.into_iter()
Expand All @@ -138,15 +146,40 @@ fn resolve_join_where(
_ => false,
})
.count();
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition")
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition");
}
let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
.map_err(|e| e.context(failed_input!(join left)))?;
let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
.map_err(|e| e.context(failed_input!(join left)))?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt
.lp_arena
.get(input_right)
.schema(ctxt.lp_arena)
.into_owned();

let owned = |e: Arc<Expr>| (*e).clone();

// Partition to:
// We do a few things
// First we partition to:
// - IEjoin supported inequality predicates
// - equality predicates
// - remaining predicates
// And then decide to which join we dispatch.
// The remaining predicates will be applied as filter.

// What make things a bit complicated is that duplicate join names
// are referred to in the query with the name post-join, but on joins
// we refer to the names pre-join (e.g. without suffix). So there is some
// bookkeeping.
//
// - First we determine which side of the binary expression refers to the left and right table
// and make sure that lhs of the binary expr, maps to the lhs of the join tables and vice versa.
// Next we ensure the suffixes are removed when we partition.
//
// If a predicate has to be applied as post-join filter, we put the suffixes back if needed.
let mut ie_left_on = vec![];
let mut ie_right_on = vec![];
let mut ie_op = vec![];
Expand All @@ -166,68 +199,150 @@ fn resolve_join_where(
}
}

fn rename_expr(e: Expr, old: &str, new: &str) -> Expr {
e.map_expr(|e| match e {
Expr::Column(name) if name.as_str() == old => Expr::Column(new.into()),
e => e,
})
}

fn determine_order_and_pre_join_names(
left: Expr,
op: Operator,
right: Expr,
schema_left: &Schema,
schema_right: &Schema,
suffix: &str,
) -> PolarsResult<(Expr, Operator, Expr)> {
let left_names = expr_to_leaf_column_names_iter(&left).collect::<PlHashSet<_>>();
let right_names = expr_to_leaf_column_names_iter(&right).collect::<PlHashSet<_>>();

// All left should be in the left schema.
let (left_names, right_names, left, op, mut right) =
if !left_names.iter().all(|n| schema_left.contains(n)) {
// If all right names are in left schema -> swap
if right_names.iter().all(|n| schema_left.contains(n)) {
(right_names, left_names, right, op.swap_operands(), left)
} else {
polars_bail!(InvalidOperation: "got ambiguous column names in 'join_where'")
}
} else {
(left_names, right_names, left, op, right)
};
for name in &left_names {
polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "got ambiguous column names in 'join_where'\n\n\
Note that you should refer to the column names as they are post-join operation.")
}

// Now we know left belongs to the left schema, rhs suffixes are dealt with.
for post_join_name in right_names {
if let Some(pre_join_name) = post_join_name.strip_suffix(suffix) {
// Name is both sides, so a suffix will be added by the join.
// We rename
if schema_right.contains(pre_join_name) && schema_left.contains(pre_join_name) {
right = rename_expr(right, &post_join_name, pre_join_name);
}
}
}
Ok((left, op, right))
}

// Make it a binary comparison and ensure the columns refer to post join names.
fn to_binary_post_join(
l: Expr,
op: Operator,
mut r: Expr,
schema_right: &Schema,
suffix: &str,
) -> Expr {
let names = expr_to_leaf_column_names_iter(&r).collect::<Vec<_>>();
for pre_join_name in &names {
if !schema_right.contains(pre_join_name) {
let post_join_name = _join_suffix_name(pre_join_name, suffix);
r = rename_expr(r, pre_join_name, post_join_name.as_str());
}
}

Expr::BinaryExpr {
left: Arc::from(l),
op,
right: Arc::from(r),
}
}

let suffix = options.args.suffix().clone();
for pred in predicates.into_iter() {
let Expr::BinaryExpr { left, op, right } = pred.clone() else {
polars_bail!(InvalidOperation: "can only join on binary expressions")
};
polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate");
let (left, op, right) = determine_order_and_pre_join_names(
owned(left),
op,
owned(right),
&schema_left,
&schema_right,
&suffix,
)?;

if let Some(ie_op_) = to_inequality_operator(&op) {
// We already have an IEjoin or an Inner join, push to remaining
if ie_op.len() >= 2 || !eq_right_on.is_empty() {
remaining_preds.push(Expr::BinaryExpr { left, op, right })
remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix))
} else {
ie_left_on.push(owned(left));
ie_right_on.push(owned(right));
ie_left_on.push(left);
ie_right_on.push(right);
ie_op.push(ie_op_)
}
} else if matches!(op, Operator::Eq) {
eq_left_on.push(owned(left));
eq_right_on.push(owned(right));
eq_left_on.push(left);
eq_right_on.push(right);
} else {
remaining_preds.push(pred);
remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix));
}
}

// Now choose a primary join and do the remaining predicates as filters
fn to_binary(l: Expr, op: Operator, r: Expr) -> Expr {
Expr::BinaryExpr {
left: Arc::from(l),
op,
right: Arc::from(r),
}
}
// Add the ie predicates to the remaining predicates buffer so that they will be executed in the
// filter node.
fn ie_predicates_to_remaining(
remaining_preds: &mut Vec<Expr>,
ie_left_on: Vec<Expr>,
ie_right_on: Vec<Expr>,
ie_op: Vec<InequalityOperator>,
schema_right: &Schema,
suffix: &str,
) {
for ((l, op), r) in ie_left_on
.into_iter()
.zip(ie_op.into_iter())
.zip(ie_right_on.into_iter())
{
remaining_preds.push(to_binary(l, op.into(), r))
remaining_preds.push(to_binary_post_join(l, op.into(), r, schema_right, suffix))
}
}

let join_node = if !eq_left_on.is_empty() {
// We found one or more equality predicates. Go into a default equi join
// as those are cheapest on avg.
let join_node = resolve_join(
input_left,
input_right,
Either::Right(input_left),
Either::Right(input_right),
eq_left_on,
eq_right_on,
vec![],
options.clone(),
ctxt,
)?;

ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op);
ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
join_node
}
// TODO! once we support single IEjoin predicates, we must add a branch for the singe ie_pred case.
Expand All @@ -240,8 +355,8 @@ fn resolve_join_where(
});

let join_node = resolve_join(
input_left,
input_right,
Either::Right(input_left),
Either::Right(input_right),
ie_left_on[..2].to_vec(),
ie_right_on[..2].to_vec(),
vec![],
Expand All @@ -258,7 +373,7 @@ fn resolve_join_where(
let r = ie_left_on.pop().unwrap();
let op = ie_op.pop().unwrap();

remaining_preds.push(to_binary(l, op.into(), r))
remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix))
}
join_node
} else {
Expand All @@ -268,16 +383,23 @@ fn resolve_join_where(
opts.args.how = JoinType::Cross;

let join_node = resolve_join(
input_left,
input_right,
Either::Right(input_left),
Either::Right(input_right),
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?;
// TODO: This can be removed once we support the single IEjoin.
ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op);
ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
join_node
};

Expand All @@ -301,8 +423,6 @@ fn resolve_join_where(
.schema(ctxt.lp_arena)
.into_owned();

let suffix = options.args.suffix();

let mut last_node = join_node;

// Ensure that the predicates use the proper suffix
Expand Down
13 changes: 2 additions & 11 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7118,20 +7118,11 @@ def join_where(
DataFrame to join with.
*predicates
(In)Equality condition to join the two table on.
The left `pl.col(..)` will refer to the left table
and the right `pl.col(..)`
to the right table.
For example: `pl.col("time") >= pl.col("duration")`
When a column name occurs in both tables, the proper suffix must
be applied in the predicate.
suffix
Suffix to append to columns with a duplicate name.
Notes
-----
This method is strict about its equality expressions.
Only 1 equality expression is allowed per predicate, where
the lhs `pl.col` refers to the left table in the join, and the
rhs `pl.col` refers to the right table.
Examples
--------
>>> east = pl.DataFrame(
Expand Down
Loading

0 comments on commit 45c8e96

Please sign in to comment.