// Patch summary (unified diff against datafusion/physical-plan/src/joins/hash_join.rs).
//
// Hunk 1 (`impl DistributedJoinState`): adds a private `num_task_partitions(&self) -> usize`
//   method that simply forwards to `self.state_impl.num_task_partitions()`.
//
// Hunk 2 (`trait DistributedJoinStateImpl`): adds `num_task_partitions(&self) -> usize` as a
//   trait method WITHOUT a default body, so every existing implementor of this public trait
//   has to provide it or it will stop compiling.
//
// Hunks 3 and 4 (`impl ExecutionPlan for HashJoinExec`): introduce a local `probe_threads`.
//   When `distributed_state` holds a value, `probe_threads` is that state's
//   `num_task_partitions()`; otherwise it falls back to
//   `self.right().output_partitioning().partition_count()`, which is the value that was
//   previously passed unconditionally. `probe_threads` then replaces that expression as an
//   argument to `collect_left_input`.
//   NOTE(review): the meaning of that `collect_left_input` argument is not visible in this
//   patch; presumably it is the number of probe-side consumers -- confirm against the callee.
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index d1281296d7f4..30f60fa1b1e6 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -85,6 +85,10 @@ impl DistributedJoinState { Self { state_impl } } + fn num_task_partitions(&self) -> usize { + self.state_impl.num_task_partitions() + } + fn poll_probe_completed( &self, mask: &BooleanBufferBuilder, @@ -103,6 +107,8 @@ pub enum DistributedProbeState { } pub trait DistributedJoinStateImpl: Send + Sync + 'static { + fn num_task_partitions(&self) -> usize; + fn poll_probe_completed( &self, mask: &BooleanBufferBuilder, @@ -788,6 +794,13 @@ impl ExecutionPlan for HashJoinExec { let reservation = MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); + let probe_threads = distributed_state + .as_ref() + .map(|s| s.num_task_partitions()) + .unwrap_or_else(|| { + self.right().output_partitioning().partition_count() + }); + collect_left_input( None, self.random_state.clone(), @@ -797,7 +810,7 @@ impl ExecutionPlan for HashJoinExec { join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), - self.right().output_partitioning().partition_count(), + probe_threads, distributed_state, ) }),