Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hook for doing distributed CollectLeft joins #269

Merged
merged 11 commits into from
Sep 20, 2024
133 changes: 119 additions & 14 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@

//! [`HashJoinExec`] Partitioned Hash Join Operator

use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
use std::{any::Any, vec};

use super::{
utils::{OnceAsync, OnceFut},
PartitionMode,
Expand All @@ -46,6 +40,12 @@ use crate::{
Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics,
};
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
Expand All @@ -72,9 +72,56 @@ use datafusion_physical_expr::expressions::UnKnownColumn;
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use ahash::RandomState;
use arrow_buffer::BooleanBuffer;
use futures::{ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;

pub struct SharedJoinState {
state_impl: Arc<dyn SharedJoinStateImpl>,
}

impl SharedJoinState {
pub fn new(state_impl: Arc<dyn SharedJoinStateImpl>) -> Self {
Self { state_impl }
}

fn num_task_partitions(&self) -> usize {
self.state_impl.num_task_partitions()
}

fn poll_probe_completed(
&self,
mask: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<SharedProbeState>> {
self.state_impl.poll_probe_completed(mask, cx)
}

fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize) {
self.state_impl.register_metrics(metrics, partition)
}
}

pub enum SharedProbeState {
// Probes are still running in other distributed tasks
Continue,
// Current task is last probe running so emit unmatched rows
// if required by join type
Ready(BooleanBuffer),
}

pub trait SharedJoinStateImpl: Send + Sync + 'static {
fn num_task_partitions(&self) -> usize;

fn poll_probe_completed(
&self,
visited_indices_bitmap: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<SharedProbeState>>;

fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize);
}

type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;

/// HashTable and input data for the left (build side) of a join
Expand All @@ -88,6 +135,7 @@ struct JoinLeftData {
/// Counter of running probe-threads, potentially
/// able to update `visited_indices_bitmap`
probe_threads_counter: AtomicUsize,
shared_state: Option<Arc<SharedJoinState>>,
/// Memory reservation that tracks memory used by `hash_map` hash table
/// `batch`. Cleared on drop.
#[allow(dead_code)]
Expand All @@ -102,12 +150,14 @@ impl JoinLeftData {
visited_indices_bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Self {
Self {
hash_map,
batch,
visited_indices_bitmap,
probe_threads_counter,
shared_state: distributed_state,
reservation,
}
}
Expand All @@ -126,14 +176,34 @@ impl JoinLeftData {
fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder {
&self.visited_indices_bitmap
}

/// Decrements the counter of running threads, and returns `true`
/// if caller is the last running thread
fn report_probe_completed(&self) -> bool {
self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
self.probe_threads_counter.load(Ordering::Relaxed) == 0
|| self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
}
}

fn merge_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> {
if m1.len() != m2.len() {
return Err(DataFusionError::Execution(format!(
"local and shared indices bitmaps have different lengths: {} and {}",
m1.len(),
m2.len()
)));
}

for (b1, b2) in m1
.as_slice_mut()
.iter_mut()
.zip(m2.inner().as_slice().iter().copied())
{
*b1 |= b2;
}

Ok(())
}

/// Join execution plan: Evaluates eqijoin predicates in parallel on multiple
/// partitions using a hash table and an optional filter list to apply post
/// join.
Expand Down Expand Up @@ -721,11 +791,25 @@ impl ExecutionPlan for HashJoinExec {
);
}

let distributed_state =
context.session_config().get_extension::<SharedJoinState>();

let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let left_fut = match self.mode {
PartitionMode::CollectLeft => self.left_fut.once(|| {
let reservation =
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());

let probe_threads = distributed_state
.as_ref()
.map(|s| {
s.register_metrics(&self.metrics, partition);
s.num_task_partitions()
})
.unwrap_or_else(|| {
self.right().output_partitioning().partition_count()
});

collect_left_input(
None,
self.random_state.clone(),
Expand All @@ -735,7 +819,8 @@ impl ExecutionPlan for HashJoinExec {
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
self.right().output_partitioning().partition_count(),
probe_threads,
distributed_state,
)
}),
PartitionMode::Partitioned => {
Expand All @@ -753,6 +838,7 @@ impl ExecutionPlan for HashJoinExec {
reservation,
need_produce_result_in_final(self.join_type),
1,
None,
))
}
PartitionMode::Auto => {
Expand Down Expand Up @@ -838,6 +924,7 @@ async fn collect_left_input(
reservation: MemoryReservation,
with_visited_indices_bitmap: bool,
probe_threads_count: usize,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Result<JoinLeftData> {
let schema = left.schema();

Expand Down Expand Up @@ -925,6 +1012,7 @@ async fn collect_left_input(
Mutex::new(visited_indices_bitmap),
AtomicUsize::new(probe_threads_count),
reservation,
distributed_state,
);

Ok(data)
Expand Down Expand Up @@ -1301,7 +1389,7 @@ impl HashJoinStream {
handle_state!(self.process_probe_batch())
}
HashJoinStreamState::ExhaustedProbeSide => {
handle_state!(self.process_unmatched_build_batch())
handle_state!(ready!(self.process_unmatched_build_batch(cx)))
}
HashJoinStreamState::Completed => Poll::Ready(None),
};
Expand Down Expand Up @@ -1486,18 +1574,35 @@ impl HashJoinStream {
/// Updates state to `Completed`
fn process_unmatched_build_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
cx: &mut Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let timer = self.join_metrics.join_time.timer();

if !need_produce_result_in_final(self.join_type) {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

let build_side = self.build_side.try_as_ready()?;
if !build_side.left_data.report_probe_completed() {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

if let Some(shared_state) = build_side.left_data.shared_state.as_ref() {
let mut guard = build_side.left_data.visited_indices_bitmap().lock();
match ready!(shared_state.poll_probe_completed(guard.deref(), cx)) {
Ok(SharedProbeState::Continue) => {
self.state = HashJoinStreamState::Completed;
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
Ok(SharedProbeState::Ready(shared_mask)) => {
if let Err(e) = merge_bitmap(guard.deref_mut(), shared_mask) {
return Poll::Ready(Err(e));
}
}
Err(err) => return Poll::Ready(Err(err)),
}
}

// use the global left bitmap to produce the left indices and right indices
Expand Down Expand Up @@ -1528,7 +1633,7 @@ impl HashJoinStream {

self.state = HashJoinStreamState::Completed;

Ok(StatefulStreamResult::Ready(Some(result?)))
Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result?))))
}
}

Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! DataFusion Join implementations

pub use cross_join::CrossJoinExec;
pub use hash_join::HashJoinExec;
pub use hash_join::{
HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
};
pub use nested_loop_join::NestedLoopJoinExec;
// Note: SortMergeJoin is not used in plans yet
pub use sort_merge_join::SortMergeJoinExec;
Expand Down
Loading