From 602a5f653ff3276f8b8dcaf38af499d8ab3bbe6e Mon Sep 17 00:00:00 2001 From: huaxingao Date: Fri, 9 Aug 2024 16:21:36 -0700 Subject: [PATCH] feat: Support sum in window function --- .../core/src/execution/datafusion/planner.rs | 121 +++++++++--------- .../apache/comet/serde/QueryPlanSerde.scala | 2 +- .../apache/comet/exec/CometExecSuite.scala | 8 +- 3 files changed, 67 insertions(+), 64 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index b604e98ba8..c9d7e17e75 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1233,38 +1233,10 @@ impl PhysicalPlanner { ) -> Result, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { - assert!(!expr.children.is_empty()); - // Using `count_udaf` from Comet is exceptionally slow for some reason, so - // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` - // https://github.com/apache/datafusion-comet/issues/744 - - let children = expr - .children - .iter() - .map(|child| self.create_expr(child, schema.clone())) - .collect::, _>>()?; - - // create `IS NOT NULL expr` and join them with `AND` if there are multiple - let not_null_expr: Arc = children.iter().skip(1).fold( - Arc::new(IsNotNullExpr::new(children[0].clone())) as Arc, - |acc, child| { - Arc::new(BinaryExpr::new( - acc, - DataFusionOperator::And, - Arc::new(IsNotNullExpr::new(child.clone())), - )) - }, - ); - - let child = Arc::new(IfExpr::new( - not_null_expr, - Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), - Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), - )); - + let if_expr = self.convert_count_to_if(&expr.children, schema.clone())?; create_aggregate_expr( &sum_udaf(), - &[child], + &[if_expr], &[], &[], &[], @@ -1277,8 +1249,6 @@ impl PhysicalPlanner { } AggExprStruct::Min(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), 
schema.clone())?; - let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); create_aggregate_expr( &min_udaf(), &[child], @@ -1294,8 +1264,6 @@ impl PhysicalPlanner { } AggExprStruct::Max(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; - let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); create_aggregate_expr( &max_udaf(), &[child], @@ -1523,6 +1491,40 @@ impl PhysicalPlanner { } } + fn convert_count_to_if( + &self, + children: &Vec, + schema: Arc, + ) -> Result, ExecutionError> { + assert!(!children.is_empty(), "Children should not be empty"); + + // Translate `COUNT` to `SUM(IF(expr IS NOT NULL, 1, 0))` + let children_exprs = children + .iter() + .map(|child| self.create_expr(child, schema.clone())) + .collect::, _>>()?; + + // Create `IS NOT NULL expr` and combine with `AND` for multiple children + let not_null_expr = children_exprs.iter().skip(1).fold( + Arc::new(IsNotNullExpr::new(children_exprs[0].clone())) as Arc, + |acc, child| { + Arc::new(BinaryExpr::new( + acc, + DataFusionOperator::And, + Arc::new(IsNotNullExpr::new(child.clone())), + )) + }, + ); + + let if_expr = Arc::new(IfExpr::new( + not_null_expr, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + + Ok(if_expr) + } + /// Create a DataFusion windows physical expression from Spark physical expression fn create_window_expr<'a>( &'a self, @@ -1531,12 +1533,17 @@ impl PhysicalPlanner { partition_by: &[Arc], sort_exprs: &[PhysicalSortExpr], ) -> Result, ExecutionError> { - let (mut window_func_name, mut window_func_args) = (String::new(), Vec::new()); + let (mut window_func_name, mut window_args): (String, Vec>) = (String::new(), Vec::new()); if let Some(func) = &spark_expr.built_in_window_function { match &func.expr_struct { 
Some(ExprStruct::ScalarFunc(f)) => { window_func_name.clone_from(&f.func); - window_func_args.clone_from(&f.args); + + window_args = f.args + .iter() + .map(|expr| self.create_expr(expr, input_schema.clone())) + .collect::, ExecutionError>>()?; + } other => { return Err(ExecutionError::GeneralError(format!( @@ -1545,9 +1552,9 @@ impl PhysicalPlanner { } }; } else if let Some(agg_func) = &spark_expr.agg_func { - let result = Self::process_agg_func(agg_func)?; + let result = self.process_agg_func(agg_func, input_schema.clone())?; window_func_name = result.0; - window_func_args = result.1; + window_args = result.1; } else { return Err(ExecutionError::GeneralError( "Both func and agg_func are not set".to_string(), @@ -1563,11 +1570,6 @@ impl PhysicalPlanner { } }; - let window_args = window_func_args - .iter() - .map(|expr| self.create_expr(expr, input_schema.clone())) - .collect::, ExecutionError>>()?; - let spark_window_frame = match spark_expr .spec .as_ref() @@ -1637,32 +1639,27 @@ impl PhysicalPlanner { .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } - fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec), ExecutionError> { - fn optional_expr_to_vec(expr_option: &Option) -> Vec { - expr_option - .as_ref() - .cloned() - .map_or_else(Vec::new, |e| vec![e]) - } - - fn int_to_stats_type(value: i32) -> Option { - match value { - 0 => Some(StatsType::Sample), - 1 => Some(StatsType::Population), - _ => None, - } - } + fn process_agg_func(&self, agg_func: &AggExpr, schema: SchemaRef) -> Result<(String, Vec>), ExecutionError> { match &agg_func.expr_struct { Some(AggExprStruct::Count(expr)) => { - let args = &expr.children; - Ok(("count".to_string(), args.to_vec())) + let if_expr = self.convert_count_to_if(&expr.children, schema.clone())?; + Ok(("count".to_string(), vec![if_expr])) + } Some(AggExprStruct::Min(expr)) => { - Ok(("min".to_string(), optional_expr_to_vec(&expr.child))) + let child = self.create_expr(expr.child.as_ref().unwrap(), 
schema.clone())?; + Ok(("min".to_string(), vec![child])) } Some(AggExprStruct::Max(expr)) => { - Ok(("max".to_string(), optional_expr_to_vec(&expr.child))) + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + Ok(("max".to_string(), vec![child])) + } + Some(AggExprStruct::Sum(expr)) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); + Ok(("sum".to_string(), vec![child])) } other => Err(ExecutionError::GeneralError(format!( "{other:?} not supported for window function" diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 135ed15b9e..8118488c76 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -209,7 +209,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim expr match { case agg: AggregateExpression => agg.aggregateFunction match { - case _: Min | _: Max | _: Count => + case _: Min | _: Max | _: Count | _: Sum => Some(agg) case _ => withInfo(windowExpr, "Unsupported aggregate", expr) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 56da81cbf7..4c60ad59c6 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1510,7 +1510,13 @@ class CometExecSuite extends CometTestBase { SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) { withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls val aggregateFunctions = - List("COUNT(_1)", "COUNT(*)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates + List( + 
"COUNT(_1)", + "COUNT(*)", + "MAX(_1)", + "MIN(_1)", + "SUM(_1)" + ) // TODO: Test all the aggregates aggregateFunctions.foreach { function => val queries = Seq(