Skip to content

Commit

Permalink
feat: Implement single inequality joins for join_where (pola-rs#18727)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreeve authored and nameexhaustion committed Sep 16, 2024
1 parent f52cb09 commit b5e2fec
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 27 deletions.
253 changes: 242 additions & 11 deletions crates/polars-ops/src/frame/join/iejoin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
mod filtered_bit_array;
mod l1_l2;

use std::cmp::min;

use filtered_bit_array::FilteredBitArray;
use l1_l2::*;
use polars_core::chunked_array::ChunkedArray;
use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType};
use polars_core::frame::DataFrame;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::{_set_partition_size, split};
use polars_core::{with_match_physical_numeric_polars_type, POOL};
use polars_error::{polars_err, PolarsResult};
use polars_utils::binary_search::ExponentialSearch;
use polars_utils::itertools::Itertools;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::total_ord::TotalEq;
use polars_utils::total_ord::{TotalEq, TotalOrd};
use polars_utils::IdxSize;
use rayon::prelude::*;
#[cfg(feature = "serde")]
Expand All @@ -40,7 +43,7 @@ impl InequalityOperator {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IEJoinOptions {
pub operator1: InequalityOperator,
pub operator2: InequalityOperator,
pub operator2: Option<InequalityOperator>,
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -61,10 +64,7 @@ fn ie_join_impl_t<T: PolarsNumericType>(
let mut left_row_idx: Vec<IdxSize> = vec![];
let mut right_row_idx: Vec<IdxSize> = vec![];

let slice_end = match slice {
Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)),
_ => None,
};
let slice_end = slice_end_index(slice);
let mut match_count = 0;

let ca: &ChunkedArray<T> = x.as_ref().as_ref();
Expand Down Expand Up @@ -130,6 +130,78 @@ fn ie_join_impl_t<T: PolarsNumericType>(
Ok((left_row_idx, right_row_idx))
}

/// Collects the matching `(left_row, right_row)` index pairs for a join on a
/// single inequality, given two inputs that are already ordered.
///
/// The algorithm relies on the inputs being ordered so that once `pred` holds
/// for a left value and some right position, it also holds for every later
/// right position; and once the right side is exhausted without a match, no
/// later left row can match either.
///
/// * `slice` - optional `(offset, len)` slice of the join result; only used
///   here to stop producing matches early (see `slice_end_index`).
/// * `left_order` / `right_order` - when present, map a position in the ordered
///   input back to the row index that is emitted; when absent the position
///   itself is emitted.
/// * `left_ordered` / `right_ordered` - the ordered values, read as
///   `ChunkedArray<T>`. Values are expected to be non-null (debug-asserted).
/// * `pred` - the inequality, called as `pred(left_value, right_value)`.
///
/// Returns two vectors of equal length holding the left and right row indices
/// of each match.
fn piecewise_merge_join_impl_t<T, P>(
    slice: Option<(i64, usize)>,
    left_order: Option<&[IdxSize]>,
    right_order: Option<&[IdxSize]>,
    left_ordered: Series,
    right_ordered: Series,
    mut pred: P,
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
where
    T: PolarsNumericType,
    P: FnMut(&T::Native, &T::Native) -> bool,
{
    let lhs: &ChunkedArray<T> = left_ordered.as_ref().as_ref();
    let rhs: &ChunkedArray<T> = right_ordered.as_ref().as_ref();
    let n_left = lhs.len();
    let n_right = rhs.len();

    debug_assert!(left_order.is_none_or(|order| order.len() == n_left));
    debug_assert!(right_order.is_none_or(|order| order.len() == n_right));

    // Upper bound on the number of matches worth producing, if any.
    let limit = slice_end_index(slice).map(|end| end as usize);

    let mut out_left: Vec<IdxSize> = Vec::new();
    let mut out_right: Vec<IdxSize> = Vec::new();
    let mut emitted: usize = 0;
    // Position in the right input; it only ever moves forward.
    let mut cursor: usize = 0;

    for left_pos in 0..n_left {
        debug_assert!(lhs.get(left_pos).is_some());
        // SAFETY: `left_pos < n_left`; the value is expected to be non-null
        // (checked by the debug assertion above).
        let left_val = unsafe { lhs.value_unchecked(left_pos) };

        // Skip right rows for which the predicate does not hold.
        while cursor < n_right {
            debug_assert!(rhs.get(cursor).is_some());
            // SAFETY: `cursor < n_right`; the value is expected to be non-null
            // (checked by the debug assertion above).
            let right_val = unsafe { rhs.value_unchecked(cursor) };
            if pred(&left_val, &right_val) {
                break;
            }
            cursor += 1;
        }
        if cursor == n_right {
            // The right side is exhausted, so no remaining left row can match.
            break;
        }

        // The predicate holds at `cursor` and therefore for every right row
        // after it; emit them all, capped by the slice limit.
        let stop = match limit {
            None => n_right,
            Some(end) => min(n_right, end - emitted + cursor),
        };
        let left_row = match left_order {
            None => left_pos as IdxSize,
            Some(order) => order[left_pos],
        };
        out_left.extend(std::iter::repeat(left_row).take(stop - cursor));
        match right_order {
            None => out_right.extend((cursor..stop).map(|pos| pos as IdxSize)),
            Some(order) => out_right.extend_from_slice(&order[cursor..stop]),
        }
        emitted += stop - cursor;

        if limit.is_some_and(|end| emitted >= end) {
            break;
        }
    }

    Ok((out_left, out_right))
}

pub(super) fn iejoin_par(
left: &DataFrame,
right: &DataFrame,
Expand Down Expand Up @@ -206,7 +278,7 @@ pub(super) fn iejoin_par(
};

if include_block {
let (l, r) = unsafe {
let (mut l, mut r) = unsafe {
(
selected_left
.iter()
Expand All @@ -218,9 +290,21 @@ pub(super) fn iejoin_par(
.collect_vec(),
)
};
let sorted_flag = if l1_descending {
IsSorted::Descending
} else {
IsSorted::Ascending
};
// We sorted using the first series
l[0].set_sorted_flag(sorted_flag);
r[0].set_sorted_flag(sorted_flag);

// Compute the row indexes
let (idx_l, idx_r) = iejoin_tuples(l, r, options, None)?;
let (idx_l, idx_r) = if options.operator2.is_some() {
iejoin_tuples(l, r, options, None)
} else {
piecewise_merge_join_tuples(l, r, options, None)
}?;

if idx_l.is_empty() {
return Ok(None);
Expand Down Expand Up @@ -264,8 +348,11 @@ pub(super) fn iejoin(
suffix: Option<PlSmallStr>,
slice: Option<(i64, usize)>,
) -> PolarsResult<DataFrame> {
let (left_row_idx, right_row_idx) =
iejoin_tuples(selected_left, selected_right, options, slice)?;
let (left_row_idx, right_row_idx) = if options.operator2.is_some() {
iejoin_tuples(selected_left, selected_right, options, slice)
} else {
piecewise_merge_join_tuples(selected_left, selected_right, options, slice)
}?;
unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) }
}

Expand Down Expand Up @@ -308,7 +395,12 @@ fn iejoin_tuples(
};

let op1 = options.operator1;
let op2 = options.operator2;
let op2 = match options.operator2 {
None => {
return Err(polars_err!(ComputeError: "IEJoin requires two inequality operators"));
},
Some(op2) => op2,
};

// Determine the sort order based on the comparison operators used.
// We want to sort L1 so that "x[i] op1 x[j]" is true for j > i,
Expand Down Expand Up @@ -381,3 +473,142 @@ fn iejoin_tuples(
};
Ok((left_row_idx, right_row_idx))
}

/// Piecewise merge join, for joins with only a single inequality.
///
/// Takes exactly one join expression per side and `options.operator1` as the
/// inequality, and returns the left and right row indices of every match,
/// with `slice` (if given) applied to the result.
///
/// # Errors
///
/// Returns a `ComputeError` when either side does not provide exactly one
/// expression, or when `options.operator2` is set. Errors from the dtype
/// comparison and cast of the right side are propagated.
fn piecewise_merge_join_tuples(
    selected_left: Vec<Series>,
    selected_right: Vec<Series>,
    options: &IEJoinOptions,
    slice: Option<(i64, usize)>,
) -> PolarsResult<(IdxCa, IdxCa)> {
    /// Returns `series` in the requested order together with the permutation
    /// that produced it, or `None` for the permutation when the input could be
    /// used as-is.
    fn get_sorted(series: Series, descending: bool) -> (Series, Option<IdxCa>) {
        let wanted_flag = if descending {
            IsSorted::Descending
        } else {
            IsSorted::Ascending
        };
        let already_ordered = series.is_sorted_flag() == wanted_flag || series.len() <= 1;
        if already_ordered && !series.has_nulls() {
            // Fast path: nothing to sort and no nulls to drop.
            return (series, None);
        }

        let null_count = series.null_count();
        let sort_options = SortOptions::default()
            .with_nulls_last(false)
            .with_order_descending(descending);

        // Nulls sort first, so slicing them off the front of the permutation
        // drops them; null values can never be part of a match.
        let order = series
            .arg_sort(sort_options)
            .slice(null_count as i64, series.len() - null_count)
            .rechunk();
        let ordered = unsafe { series.take_unchecked(&order) };
        (ordered, Some(order))
    }

    if selected_left.len() != 1 {
        return Err(
            polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the left DataFrame"),
        );
    }
    if selected_right.len() != 1 {
        return Err(
            polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the right DataFrame"),
        );
    }
    if options.operator2.is_some() {
        return Err(
            polars_err!(ComputeError: "Piecewise merge join expects only one inequality operator"),
        );
    }

    let op = options.operator1;

    let left = selected_left[0].to_physical_repr().into_owned();
    let right = {
        let physical = selected_right[0].to_physical_repr().into_owned();
        let must_cast = physical.dtype().matches_schema_type(left.dtype())?;
        if must_cast {
            physical.cast(left.dtype())?
        } else {
            physical
        }
    };

    // Ordering contract for the merge:
    // - left: once the condition is false for a RHS row, it stays false for
    //   that RHS row and every following LHS row;
    // - right: once the condition is true for a LHS row, it stays true for
    //   that LHS row and every following RHS row.
    // This should match the l1 order used in iejoin_par, so slices do not
    // need to be re-sorted when doing a parallel join.
    let descending = matches!(op, InequalityOperator::Gt | InequalityOperator::GtEq);

    let (left_ordered, left_perm) = get_sorted(left, descending);
    debug_assert!(left_perm
        .as_ref()
        .is_none_or(|order| order.chunks().len() == 1));
    let left_order = left_perm
        .as_ref()
        .map(|order| order.downcast_get(0).unwrap().values().as_slice());

    let (right_ordered, right_perm) = get_sorted(right, descending);
    debug_assert!(right_perm
        .as_ref()
        .is_none_or(|order| order.chunks().len() == 1));
    let right_order = right_perm
        .as_ref()
        .map(|order| order.downcast_get(0).unwrap().values().as_slice());

    let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(left_ordered.dtype(), |$T| {
        match op {
            InequalityOperator::Lt => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |l, r| l.tot_lt(r),
            ),
            InequalityOperator::LtEq => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |l, r| l.tot_le(r),
            ),
            InequalityOperator::Gt => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |l, r| l.tot_gt(r),
            ),
            InequalityOperator::GtEq => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |l, r| l.tot_ge(r),
            ),
        }
    })?;

    debug_assert_eq!(left_row_idx.len(), right_row_idx.len());
    let left_matches = IdxCa::from_vec("".into(), left_row_idx);
    let right_matches = IdxCa::from_vec("".into(), right_row_idx);

    Ok(match slice {
        Some((offset, len)) => (
            left_matches.slice(offset, len),
            right_matches.slice(offset, len),
        ),
        None => (left_matches, right_matches),
    })
}

/// Computes the exclusive end position (`offset + len`) of a result slice.
///
/// Returns `None` when no slice is given or when the offset is negative.
/// The sum saturates at `i64::MAX` instead of overflowing.
fn slice_end_index(slice: Option<(i64, usize)>) -> Option<i64> {
    let (offset, len) = slice?;
    if offset < 0 {
        return None;
    }
    Some(offset.saturating_add_unsigned(len as u64))
}
37 changes: 21 additions & 16 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,12 @@ fn resolve_join_where(
&suffix,
);
join_node
}
// TODO! once we support single IEjoin predicates, we must add a branch for the single ie_pred case.
else if ie_right_on.len() >= 2 {
} else if ie_right_on.len() >= 2 {
// Do an IEjoin.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: ie_op[1],
operator2: Some(ie_op[1]),
});

let join_node = resolve_join(
Expand All @@ -390,31 +388,38 @@ fn resolve_join_where(
remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix))
}
join_node
} else if ie_right_on.len() == 1 {
// For a single inequality comparison, we use the piecewise merge join algorithm
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: None,
});

resolve_join(
Either::Right(input_left),
Either::Right(input_right),
ie_left_on,
ie_right_on,
vec![],
options.clone(),
ctxt,
)?
} else {
// No predicates found that are supported in a fast algorithm.
// Do a cross join and follow up with filters.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

let join_node = resolve_join(
resolve_join(
Either::Right(input_left),
Either::Right(input_right),
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?;
// TODO: This can be removed once we support the single IEjoin.
ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
join_node
)?
};

let IR::Join {
Expand Down
Loading

0 comments on commit b5e2fec

Please sign in to comment.