Skip to content

Commit

Permalink
Use naming conventions from upstream PR
Browse files: browse the repository at this point in the history
  • Loading branch information
thinkharderdev committed Sep 19, 2024
1 parent 8b48588 commit 937f712
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
43 changes: 21 additions & 22 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ use arrow_buffer::BooleanBuffer;
use futures::{ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;

pub struct DistributedJoinState {
state_impl: Arc<dyn DistributedJoinStateImpl>,
pub struct SharedJoinState {
state_impl: Arc<dyn SharedJoinStateImpl>,
}

impl DistributedJoinState {
pub fn new(state_impl: Arc<dyn DistributedJoinStateImpl>) -> Self {
impl SharedJoinState {
pub fn new(state_impl: Arc<dyn SharedJoinStateImpl>) -> Self {
Self { state_impl }
}

Expand All @@ -93,27 +93,27 @@ impl DistributedJoinState {
&self,
mask: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<DistributedProbeState>> {
) -> Poll<Result<SharedProbeState>> {
self.state_impl.poll_probe_completed(mask, cx)
}
}

pub enum DistributedProbeState {
pub enum SharedProbeState {
// Probes are still running in other distributed tasks
Continue,
// The current task is the last probe still running, so emit unmatched
// rows if required by the join type
Ready(BooleanBuffer),
}

pub trait DistributedJoinStateImpl: Send + Sync + 'static {
pub trait SharedJoinStateImpl: Send + Sync + 'static {
fn num_task_partitions(&self) -> usize;

fn poll_probe_completed(
&self,
mask: &BooleanBufferBuilder,
visited_indices_bitmap: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<DistributedProbeState>>;
) -> Poll<Result<SharedProbeState>>;
}

type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;
Expand All @@ -129,7 +129,7 @@ struct JoinLeftData {
/// Counter of running probe-threads, potentially
/// able to update `visited_indices_bitmap`
probe_threads_counter: AtomicUsize,
distributed_state: Option<Arc<DistributedJoinState>>,
shared_state: Option<Arc<SharedJoinState>>,
/// Memory reservation that tracks memory used by the `hash_map` hash table
/// and `batch`. Cleared on drop.
#[allow(dead_code)]
Expand All @@ -144,14 +144,14 @@ impl JoinLeftData {
visited_indices_bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
distributed_state: Option<Arc<DistributedJoinState>>,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Self {
Self {
hash_map,
batch,
visited_indices_bitmap,
probe_threads_counter,
distributed_state,
shared_state: distributed_state,
reservation,
}
}
Expand All @@ -178,7 +178,7 @@ impl JoinLeftData {
}
}

fn merge_indices_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> {
fn merge_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> {
if m1.len() != m2.len() {
return Err(DataFusionError::Execution(format!(
"local and shared indices bitmaps have different lengths: {} and {}",
Expand Down Expand Up @@ -785,9 +785,8 @@ impl ExecutionPlan for HashJoinExec {
);
}

let distributed_state = context
.session_config()
.get_extension::<DistributedJoinState>();
let distributed_state =
context.session_config().get_extension::<SharedJoinState>();

let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let left_fut = match self.mode {
Expand Down Expand Up @@ -916,7 +915,7 @@ async fn collect_left_input(
reservation: MemoryReservation,
with_visited_indices_bitmap: bool,
probe_threads_count: usize,
distributed_state: Option<Arc<DistributedJoinState>>,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Result<JoinLeftData> {
let schema = left.schema();

Expand Down Expand Up @@ -1581,15 +1580,15 @@ impl HashJoinStream {
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

if let Some(dist_state) = build_side.left_data.distributed_state.as_ref() {
if let Some(shared_state) = build_side.left_data.shared_state.as_ref() {
let mut guard = build_side.left_data.visited_indices_bitmap().lock();
match ready!(dist_state.poll_probe_completed(guard.deref(), cx)) {
Ok(DistributedProbeState::Continue) => {
match ready!(shared_state.poll_probe_completed(guard.deref(), cx)) {
Ok(SharedProbeState::Continue) => {
self.state = HashJoinStreamState::Completed;
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
Ok(DistributedProbeState::Ready(shared_mask)) => {
if let Err(e) = merge_indices_bitmap(guard.deref_mut(), shared_mask) {
Ok(SharedProbeState::Ready(shared_mask)) => {
if let Err(e) = merge_bitmap(guard.deref_mut(), shared_mask) {
return Poll::Ready(Err(e));
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

pub use cross_join::CrossJoinExec;
pub use hash_join::{
DistributedJoinState, DistributedJoinStateImpl, DistributedProbeState, HashJoinExec,
HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
};
pub use nested_loop_join::NestedLoopJoinExec;
// Note: SortMergeJoin is not used in plans yet
Expand Down

0 comments on commit 937f712

Please sign in to comment.