diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index e9d6354e21d7..fcd5fd5a5a07 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -40,7 +40,6 @@ use crate::projection::{ use crate::spill::get_record_batch_memory_size; use crate::ExecutionPlanProperties; use crate::{ - coalesce_partitions::CoalescePartitionsExec, common::can_project, handle_state, hash_utils::create_hashes, @@ -791,34 +790,44 @@ impl ExecutionPlan for HashJoinExec { ); } + if self.mode == PartitionMode::CollectLeft && left_partitions != 1 { + return internal_err!( + "Invalid HashJoinExec,the output partition count of the left child must be 1 in CollectLeft mode,\ + consider using CoalescePartitionsExec" + ); + } + let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); let left_fut = match self.mode { - PartitionMode::CollectLeft => self.left_fut.once(|| { - let reservation = - MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); - collect_left_input( - None, - self.random_state.clone(), - Arc::clone(&self.left), - on_left.clone(), - Arc::clone(&context), - join_metrics.clone(), - reservation, - need_produce_result_in_final(self.join_type), - self.right().output_partitioning().partition_count(), - ) - }), + PartitionMode::CollectLeft => { + let left_stream = self.left.execute(0, Arc::clone(&context))?; + + self.left_fut.once(|| { + let reservation = MemoryConsumer::new("HashJoinInput") + .register(context.memory_pool()); + + collect_left_input( + self.random_state.clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, + need_produce_result_in_final(self.join_type), + self.right().output_partitioning().partition_count(), + ) + }) + } PartitionMode::Partitioned => { + let left_stream = self.left.execute(partition, Arc::clone(&context))?; + let reservation = MemoryConsumer::new(format!("HashJoinInput[{partition}]")) .register(context.memory_pool()); OnceFut::new(collect_left_input( - Some(partition), self.random_state.clone(), - Arc::clone(&self.left), + left_stream, on_left.clone(), - Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), @@ -929,36 +938,22 @@ impl ExecutionPlan for HashJoinExec { /// Reads the left (build) side of the input, buffering it in memory, to build a /// hash table (`LeftJoinData`) -#[allow(clippy::too_many_arguments)] async fn collect_left_input( - partition: Option, random_state: RandomState, - left: Arc, + left_stream: SendableRecordBatchStream, on_left: Vec, - context: Arc, metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, with_visited_indices_bitmap: bool, probe_threads_count: usize, ) -> Result { - let schema = left.schema(); - - let (left_input, left_input_partition) = if let Some(partition) = partition { - (left, partition) - } else if left.output_partitioning().partition_count() != 1 { - (Arc::new(CoalescePartitionsExec::new(left)) as _, 0) - } else { - (left, 0) - }; - - // Depending on partition argument load single partition or whole left side in memory - let stream = left_input.execute(left_input_partition, Arc::clone(&context))?; + let schema = left_stream.schema(); // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream // 2. stores the batches in a vector. let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = stream + let (batches, num_rows, metrics, mut reservation) = left_stream .try_fold(initial, |mut acc, batch| async { let batch_size = get_record_batch_memory_size(&batch); // Reserve memory for incoming batch @@ -1654,6 +1649,7 @@ impl EmbeddedProjection for HashJoinExec { #[cfg(test)] mod tests { use super::*; + use crate::coalesce_partitions::CoalescePartitionsExec; use crate::test::TestMemoryExec; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, @@ -2101,6 +2097,7 @@ mod tests { let left = TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None) .unwrap(); + let left = Arc::new(CoalescePartitionsExec::new(left)); let right = build_table( ("a1", &vec![1, 2, 3]), @@ -2173,6 +2170,7 @@ mod tests { let left = TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None) .unwrap(); + let left = Arc::new(CoalescePartitionsExec::new(left)); let right = build_table( ("a2", &vec![20, 30, 10]), ("b2", &vec![5, 6, 4]),