Skip to content

Commit ddb672f

Browse files
committed
feat: Support sum in window function
1 parent 25f69bc commit ddb672f

File tree

3 files changed

+39
-32
lines changed

3 files changed

+39
-32
lines changed

native/core/src/execution/datafusion/planner.rs

+31-30
Original file line numberDiff line numberDiff line change
@@ -1531,12 +1531,17 @@ impl PhysicalPlanner {
15311531
partition_by: &[Arc<dyn PhysicalExpr>],
15321532
sort_exprs: &[PhysicalSortExpr],
15331533
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
1534-
let (mut window_func_name, mut window_func_args) = (String::new(), Vec::new());
1534+
let window_func_name: String;
1535+
let window_args: Vec<Arc<dyn PhysicalExpr>>;
15351536
if let Some(func) = &spark_expr.built_in_window_function {
15361537
match &func.expr_struct {
15371538
Some(ExprStruct::ScalarFunc(f)) => {
1538-
window_func_name.clone_from(&f.func);
1539-
window_func_args.clone_from(&f.args);
1539+
window_func_name = f.func.clone();
1540+
window_args = f
1541+
.args
1542+
.iter()
1543+
.map(|expr| self.create_expr(expr, input_schema.clone()))
1544+
.collect::<Result<Vec<_>, ExecutionError>>()?;
15401545
}
15411546
other => {
15421547
return Err(ExecutionError::GeneralError(format!(
@@ -1545,9 +1550,9 @@ impl PhysicalPlanner {
15451550
}
15461551
};
15471552
} else if let Some(agg_func) = &spark_expr.agg_func {
1548-
let result = Self::process_agg_func(agg_func)?;
1553+
let result = self.process_agg_func(agg_func, input_schema.clone())?;
15491554
window_func_name = result.0;
1550-
window_func_args = result.1;
1555+
window_args = result.1;
15511556
} else {
15521557
return Err(ExecutionError::GeneralError(
15531558
"Both func and agg_func are not set".to_string(),
@@ -1563,11 +1568,6 @@ impl PhysicalPlanner {
15631568
}
15641569
};
15651570

1566-
let window_args = window_func_args
1567-
.iter()
1568-
.map(|expr| self.create_expr(expr, input_schema.clone()))
1569-
.collect::<Result<Vec<_>, ExecutionError>>()?;
1570-
15711571
let spark_window_frame = match spark_expr
15721572
.spec
15731573
.as_ref()
@@ -1637,32 +1637,33 @@ impl PhysicalPlanner {
16371637
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
16381638
}
16391639

1640-
fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec<Expr>), ExecutionError> {
1641-
fn optional_expr_to_vec(expr_option: &Option<Expr>) -> Vec<Expr> {
1642-
expr_option
1643-
.as_ref()
1644-
.cloned()
1645-
.map_or_else(Vec::new, |e| vec![e])
1646-
}
1647-
1648-
fn int_to_stats_type(value: i32) -> Option<StatsType> {
1649-
match value {
1650-
0 => Some(StatsType::Sample),
1651-
1 => Some(StatsType::Population),
1652-
_ => None,
1653-
}
1654-
}
1655-
1640+
fn process_agg_func(
1641+
&self,
1642+
agg_func: &AggExpr,
1643+
schema: SchemaRef,
1644+
) -> Result<(String, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
16561645
match &agg_func.expr_struct {
16571646
Some(AggExprStruct::Count(expr)) => {
1658-
let args = &expr.children;
1659-
Ok(("count".to_string(), args.to_vec()))
1647+
let children = expr
1648+
.children
1649+
.iter()
1650+
.map(|child| self.create_expr(child, schema.clone()))
1651+
.collect::<Result<Vec<_>, _>>()?;
1652+
Ok(("count".to_string(), children))
16601653
}
16611654
Some(AggExprStruct::Min(expr)) => {
1662-
Ok(("min".to_string(), optional_expr_to_vec(&expr.child)))
1655+
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
1656+
Ok(("min".to_string(), vec![child]))
16631657
}
16641658
Some(AggExprStruct::Max(expr)) => {
1665-
Ok(("max".to_string(), optional_expr_to_vec(&expr.child)))
1659+
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
1660+
Ok(("max".to_string(), vec![child]))
1661+
}
1662+
Some(AggExprStruct::Sum(expr)) => {
1663+
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
1664+
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
1665+
let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
1666+
Ok(("sum".to_string(), vec![child]))
16661667
}
16671668
other => Err(ExecutionError::GeneralError(format!(
16681669
"{other:?} not supported for window function"

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
209209
expr match {
210210
case agg: AggregateExpression =>
211211
agg.aggregateFunction match {
212-
case _: Min | _: Max | _: Count =>
212+
case _: Min | _: Max | _: Count | _: Sum =>
213213
Some(agg)
214214
case _ =>
215215
withInfo(windowExpr, "Unsupported aggregate", expr)

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,13 @@ class CometExecSuite extends CometTestBase {
15101510
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
15111511
withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls
15121512
val aggregateFunctions =
1513-
List("COUNT(_1)", "COUNT(*)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates
1513+
List(
1514+
"COUNT(_1)",
1515+
"COUNT(*)",
1516+
"MAX(_1)",
1517+
"MIN(_1)",
1518+
"SUM(_1)"
1519+
) // TODO: Test all the aggregates
15141520

15151521
aggregateFunctions.foreach { function =>
15161522
val queries = Seq(

0 commit comments

Comments
 (0)