diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index a277526c6e31..50d3dcb1df6b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -83,6 +83,8 @@ pub enum AggregateMode { /// two operators. /// This mode requires tha the input is partitioned by group key (like FinalPartitioned) SinglePartitioned, + /// Combine Partials + CombinePartial, } impl AggregateMode { @@ -94,7 +96,7 @@ impl AggregateMode { AggregateMode::Partial | AggregateMode::Single | AggregateMode::SinglePartitioned => true, - AggregateMode::Final | AggregateMode::FinalPartitioned => false, + AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => false, } } } @@ -651,7 +653,7 @@ impl ExecutionPlan for AggregateExec { fn required_input_distribution(&self) -> Vec { match &self.mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::CombinePartial => { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { @@ -781,7 +783,7 @@ fn create_schema( } match mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::CombinePartial => { // in partial mode, the fields of the accumulator's state for expr in aggr_expr { fields.extend(expr.state_fields()?.iter().cloned()) @@ -1050,7 +1052,7 @@ fn aggregate_expressions( }) .collect()), // In this mode, we build the merge expressions of the aggregation. - AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => { let mut col_idx_base = col_idx_base; aggr_expr .iter() @@ -1099,7 +1101,7 @@ fn finalize_aggregation( mode: &AggregateMode, ) -> Result> { match mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::CombinePartial => { // Build the vector of states accumulators .iter_mut() diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 5ec95bd79942..7062e3be70a2 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -81,7 +81,8 @@ impl AggregateStream { let filter_expressions = match agg.mode { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, + | AggregateMode::SinglePartitioned + | AggregateMode::CombinePartial => agg_filter_expr, AggregateMode::Final | AggregateMode::FinalPartitioned => { vec![None; agg.aggr_expr.len()] } @@ -230,7 +231,7 @@ fn aggregate_batch( AggregateMode::Partial | AggregateMode::Single | AggregateMode::SinglePartitioned => accum.update_batch(values), - AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => { accum.merge_batch(values) } }; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index f9db0a050cfc..1e72352c2989 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -313,7 +313,8 @@ impl GroupedHashAggregateStream { let filter_expressions = match agg.mode { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, + | AggregateMode::SinglePartitioned + | AggregateMode::CombinePartial => agg_filter_expr, AggregateMode::Final | AggregateMode::FinalPartitioned => { vec![None; agg.aggr_expr.len()] } @@ -640,7 +641,8 @@ impl GroupedHashAggregateStream { // Next output each aggregate value for acc in self.accumulators.iter_mut() { match self.mode { - AggregateMode::Partial => output.extend(acc.state(emit_to)?), + AggregateMode::Partial + | AggregateMode::CombinePartial => output.extend(acc.state(emit_to)?), _ if spilling => { // If spilling, output partial state because the spilled data will be // merged and re-evaluated later. diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e779e29cb8da..7727849fdec4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1604,6 +1604,7 @@ enum AggregateMode { FINAL_PARTITIONED = 2; SINGLE = 3; SINGLE_PARTITIONED = 4; + COMBINE_PARTIAL = 5; } message PartiallySortedInputOrderMode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f5f15aa3e428..d84b838e2abe 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -602,6 +602,7 @@ impl serde::Serialize for AggregateMode { Self::FinalPartitioned => "FINAL_PARTITIONED", Self::Single => "SINGLE", Self::SinglePartitioned => "SINGLE_PARTITIONED", + Self::CombinePartial => "COMBINE_PARTIAL", }; serializer.serialize_str(variant) } @@ -618,6 +619,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL_PARTITIONED", "SINGLE", "SINGLE_PARTITIONED", + "COMBINE_PARTIAL", ]; struct GeneratedVisitor; @@ -663,6 +665,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned), "SINGLE" => Ok(AggregateMode::Single), "SINGLE_PARTITIONED" => Ok(AggregateMode::SinglePartitioned), + "COMBINE_PARTIAL" => Ok(AggregateMode::CombinePartial), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 69d035239cb8..b09fa8319d96 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -3506,6 +3506,7 @@ pub enum AggregateMode { FinalPartitioned = 2, Single = 3, SinglePartitioned = 4, + CombinePartial = 5, } impl AggregateMode { /// String value of the enum field names used in the ProtoBuf definition. @@ -3519,6 +3520,7 @@ impl AggregateMode { AggregateMode::FinalPartitioned => "FINAL_PARTITIONED", AggregateMode::Single => "SINGLE", AggregateMode::SinglePartitioned => "SINGLE_PARTITIONED", + AggregateMode::CombinePartial => "COMBINE_PARTIAL", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3529,6 +3531,7 @@ impl AggregateMode { "FINAL_PARTITIONED" => Some(Self::FinalPartitioned), "SINGLE" => Some(Self::Single), "SINGLE_PARTITIONED" => Some(Self::SinglePartitioned), + "COMBINE_PARTIAL" => Some(Self::CombinePartial), _ => None, } } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d2961875d89a..868c942bab64 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -381,6 +381,9 @@ impl AsExecutionPlan for PhysicalPlanNode { protobuf::AggregateMode::SinglePartitioned => { AggregateMode::SinglePartitioned } + protobuf::AggregateMode::CombinePartial => { + AggregateMode::CombinePartial + } }; let num_expr = hash_agg.group_expr.len(); @@ -1390,7 +1393,9 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateMode::SinglePartitioned => { protobuf::AggregateMode::SinglePartitioned } + AggregateMode::CombinePartial => protobuf::AggregateMode::CombinePartial, }; + let input_schema = exec.input_schema(); let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(),