Skip to content

Commit

Permalink
feat: Support sum in window function
Browse files Browse the repository at this point in the history
  • Loading branch information
huaxingao committed Aug 9, 2024
1 parent 25f69bc commit 602a5f6
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 64 deletions.
121 changes: 59 additions & 62 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1233,38 +1233,10 @@ impl PhysicalPlanner {
) -> Result<Arc<dyn AggregateExpr>, ExecutionError> {
match spark_expr.expr_struct.as_ref().unwrap() {
AggExprStruct::Count(expr) => {
assert!(!expr.children.is_empty());
// Using `count_udaf` from Comet is exceptionally slow for some reason, so
// as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
// https://github.com/apache/datafusion-comet/issues/744

let children = expr
.children
.iter()
.map(|child| self.create_expr(child, schema.clone()))
.collect::<Result<Vec<_>, _>>()?;

// create `IS NOT NULL expr` and join them with `AND` if there are multiple
let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold(
Arc::new(IsNotNullExpr::new(children[0].clone())) as Arc<dyn PhysicalExpr>,
|acc, child| {
Arc::new(BinaryExpr::new(
acc,
DataFusionOperator::And,
Arc::new(IsNotNullExpr::new(child.clone())),
))
},
);

let child = Arc::new(IfExpr::new(
not_null_expr,
Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
));

let if_expr = self.convert_count_to_if(&expr.children, schema.clone())?;
create_aggregate_expr(
&sum_udaf(),
&[child],
&[if_expr],
&[],
&[],
&[],
Expand All @@ -1277,8 +1249,6 @@ impl PhysicalPlanner {
}
AggExprStruct::Min(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
create_aggregate_expr(
&min_udaf(),
&[child],
Expand All @@ -1294,8 +1264,6 @@ impl PhysicalPlanner {
}
AggExprStruct::Max(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
create_aggregate_expr(
&max_udaf(),
&[child],
Expand Down Expand Up @@ -1523,6 +1491,40 @@ impl PhysicalPlanner {
}
}

/// Translate the children of a Spark `COUNT` into the physical expression
/// `IF(child1 IS NOT NULL AND child2 IS NOT NULL ..., 1, 0)`, so the caller
/// can wrap it in `SUM` as a faster replacement for `count_udaf`
/// (see https://github.com/apache/datafusion-comet/issues/744).
///
/// Accepts any slice of children (`&[Expr]` is more general than `&Vec<Expr>`
/// and both call sites coerce to it — clippy `ptr_arg`).
///
/// # Panics
/// Panics if `children` is empty; Spark always supplies at least one child
/// for `COUNT`, so an empty slice indicates a planner bug.
///
/// # Errors
/// Propagates any `ExecutionError` from converting a child expression.
fn convert_count_to_if(
    &self,
    children: &[Expr],
    schema: Arc<Schema>,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    assert!(!children.is_empty(), "Children should not be empty");

    let children_exprs = children
        .iter()
        .map(|child| self.create_expr(child, schema.clone()))
        .collect::<Result<Vec<_>, _>>()?;

    // Combine the per-child `IS NOT NULL` predicates with `AND`; a single
    // child yields just `child IS NOT NULL` (the fold body never runs).
    let not_null_expr = children_exprs.iter().skip(1).fold(
        Arc::new(IsNotNullExpr::new(children_exprs[0].clone())) as Arc<dyn PhysicalExpr>,
        |acc, child| {
            Arc::new(BinaryExpr::new(
                acc,
                DataFusionOperator::And,
                Arc::new(IsNotNullExpr::new(child.clone())),
            ))
        },
    );

    // 1 when every child is non-null, 0 otherwise.
    Ok(Arc::new(IfExpr::new(
        not_null_expr,
        Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
        Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
    )))
}

/// Create a DataFusion windows physical expression from Spark physical expression
fn create_window_expr<'a>(
&'a self,
Expand All @@ -1531,12 +1533,17 @@ impl PhysicalPlanner {
partition_by: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[PhysicalSortExpr],
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
let (mut window_func_name, mut window_func_args) = (String::new(), Vec::new());
let (mut window_func_name, mut window_args): (String, Vec<Arc<dyn PhysicalExpr>>) = (String::new(), Vec::new());
if let Some(func) = &spark_expr.built_in_window_function {
match &func.expr_struct {
Some(ExprStruct::ScalarFunc(f)) => {
window_func_name.clone_from(&f.func);
window_func_args.clone_from(&f.args);

window_args = f.args
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect::<Result<Vec<_>, ExecutionError>>()?;

}
other => {
return Err(ExecutionError::GeneralError(format!(
Expand All @@ -1545,9 +1552,9 @@ impl PhysicalPlanner {
}
};
} else if let Some(agg_func) = &spark_expr.agg_func {
let result = Self::process_agg_func(agg_func)?;
let result = self.process_agg_func(agg_func, input_schema.clone())?;
window_func_name = result.0;
window_func_args = result.1;
window_args = result.1;
} else {
return Err(ExecutionError::GeneralError(
"Both func and agg_func are not set".to_string(),
Expand All @@ -1563,11 +1570,6 @@ impl PhysicalPlanner {
}
};

let window_args = window_func_args
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect::<Result<Vec<_>, ExecutionError>>()?;

let spark_window_frame = match spark_expr
.spec
.as_ref()
Expand Down Expand Up @@ -1637,32 +1639,27 @@ impl PhysicalPlanner {
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}

fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec<Expr>), ExecutionError> {
fn optional_expr_to_vec(expr_option: &Option<Expr>) -> Vec<Expr> {
expr_option
.as_ref()
.cloned()
.map_or_else(Vec::new, |e| vec![e])
}

fn int_to_stats_type(value: i32) -> Option<StatsType> {
match value {
0 => Some(StatsType::Sample),
1 => Some(StatsType::Population),
_ => None,
}
}
fn process_agg_func(&self, agg_func: &AggExpr, schema: SchemaRef) -> Result<(String, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {

match &agg_func.expr_struct {
Some(AggExprStruct::Count(expr)) => {
let args = &expr.children;
Ok(("count".to_string(), args.to_vec()))
let if_expr = self.convert_count_to_if(&expr.children, schema.clone())?;
Ok(("count".to_string(), vec![if_expr]))

}
Some(AggExprStruct::Min(expr)) => {
Ok(("min".to_string(), optional_expr_to_vec(&expr.child)))
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
Ok(("min".to_string(), vec![child]))
}
Some(AggExprStruct::Max(expr)) => {
Ok(("max".to_string(), optional_expr_to_vec(&expr.child)))
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
Ok(("max".to_string(), vec![child]))
}
Some(AggExprStruct::Sum(expr)) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
Ok(("sum".to_string(), vec![child]))
}
other => Err(ExecutionError::GeneralError(format!(
"{other:?} not supported for window function"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expr match {
case agg: AggregateExpression =>
agg.aggregateFunction match {
case _: Min | _: Max | _: Count =>
case _: Min | _: Max | _: Count | _: Sum | _: Min | _: Max =>
Some(agg)
case _ =>
withInfo(windowExpr, "Unsupported aggregate", expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,13 @@ class CometExecSuite extends CometTestBase {
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls
val aggregateFunctions =
List("COUNT(_1)", "COUNT(*)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates
List(
"COUNT(_1)",
"COUNT(*)",
"MAX(_1)",
"MIN(_1)",
"SUM(_1)"
) // TODO: Test all the aggregates

aggregateFunctions.foreach { function =>
val queries = Seq(
Expand Down

0 comments on commit 602a5f6

Please sign in to comment.