Skip to content

Commit

Permalink
Get probe threads from distributed state
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkharderdev committed Sep 14, 2024
1 parent 6fd7e2c commit 3a6725f
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ impl DistributedJoinState {
Self { state_impl }
}

fn num_task_partitions(&self) -> usize {
self.state_impl.num_task_partitions()
}

fn poll_probe_completed(
&self,
mask: &BooleanBufferBuilder,
Expand All @@ -103,6 +107,8 @@ pub enum DistributedProbeState {
}

pub trait DistributedJoinStateImpl: Send + Sync + 'static {
fn num_task_partitions(&self) -> usize;

fn poll_probe_completed(
&self,
mask: &BooleanBufferBuilder,
Expand Down Expand Up @@ -788,6 +794,13 @@ impl ExecutionPlan for HashJoinExec {
let reservation =
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());

let probe_threads = distributed_state
.as_ref()
.map(|s| s.num_task_partitions())
.unwrap_or_else(|| {
self.right().output_partitioning().partition_count()
});

collect_left_input(
None,
self.random_state.clone(),
Expand All @@ -797,7 +810,7 @@ impl ExecutionPlan for HashJoinExec {
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
self.right().output_partitioning().partition_count(),
probe_threads,
distributed_state,
)
}),
Expand Down

0 comments on commit 3a6725f

Please sign in to comment.