Skip to content

Commit d82d151

Browse files
committed
fix test failure
1 parent 46c40b3 commit d82d151

File tree

2 files changed

+37
-39
lines changed

2 files changed

+37
-39
lines changed

native/core/src/execution/datafusion/planner.rs

+36-38
Original file line numberDiff line numberDiff line change
@@ -1233,10 +1233,38 @@ impl PhysicalPlanner {
12331233
) -> Result<Arc<dyn AggregateExpr>, ExecutionError> {
12341234
match spark_expr.expr_struct.as_ref().unwrap() {
12351235
AggExprStruct::Count(expr) => {
1236-
let if_expr = self.convert_count_to_if(&expr.children, schema.clone())?;
1236+
assert!(!expr.children.is_empty());
1237+
// Using `count_udaf` from Comet is exceptionally slow for some reason, so
1238+
// as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
1239+
// https://github.com/apache/datafusion-comet/issues/744
1240+
1241+
let children = expr
1242+
.children
1243+
.iter()
1244+
.map(|child| self.create_expr(child, schema.clone()))
1245+
.collect::<Result<Vec<_>, _>>()?;
1246+
1247+
// create `IS NOT NULL expr` and join them with `AND` if there are multiple
1248+
let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold(
1249+
Arc::new(IsNotNullExpr::new(children[0].clone())) as Arc<dyn PhysicalExpr>,
1250+
|acc, child| {
1251+
Arc::new(BinaryExpr::new(
1252+
acc,
1253+
DataFusionOperator::And,
1254+
Arc::new(IsNotNullExpr::new(child.clone())),
1255+
))
1256+
},
1257+
);
1258+
1259+
let child = Arc::new(IfExpr::new(
1260+
not_null_expr,
1261+
Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
1262+
Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
1263+
));
1264+
12371265
create_aggregate_expr(
12381266
&sum_udaf(),
1239-
&[if_expr],
1267+
&[child],
12401268
&[],
12411269
&[],
12421270
&[],
@@ -1491,40 +1519,6 @@ impl PhysicalPlanner {
14911519
}
14921520
}
14931521

1494-
fn convert_count_to_if(
1495-
&self,
1496-
children: &[Expr],
1497-
schema: Arc<Schema>,
1498-
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
1499-
assert!(!children.is_empty(), "Children should not be empty");
1500-
1501-
// Translate `COUNT` to `SUM(IF(expr IS NOT NULL, 1, 0))`
1502-
let children_exprs = children
1503-
.iter()
1504-
.map(|child| self.create_expr(child, schema.clone()))
1505-
.collect::<Result<Vec<_>, _>>()?;
1506-
1507-
// Create `IS NOT NULL expr` and combine with `AND` for multiple children
1508-
let not_null_expr = children_exprs.iter().skip(1).fold(
1509-
Arc::new(IsNotNullExpr::new(children_exprs[0].clone())) as Arc<dyn PhysicalExpr>,
1510-
|acc, child| {
1511-
Arc::new(BinaryExpr::new(
1512-
acc,
1513-
DataFusionOperator::And,
1514-
Arc::new(IsNotNullExpr::new(child.clone())),
1515-
))
1516-
},
1517-
);
1518-
1519-
let if_expr = Arc::new(IfExpr::new(
1520-
not_null_expr,
1521-
Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
1522-
Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
1523-
));
1524-
1525-
Ok(if_expr)
1526-
}
1527-
15281522
/// Create a DataFusion windows physical expression from Spark physical expression
15291523
fn create_window_expr<'a>(
15301524
&'a self,
@@ -1646,8 +1640,12 @@ impl PhysicalPlanner {
16461640
) -> Result<(String, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
16471641
match &agg_func.expr_struct {
16481642
Some(AggExprStruct::Count(expr)) => {
1649-
let if_expr = self.convert_count_to_if(&expr.children, schema.clone())?;
1650-
Ok(("count".to_string(), vec![if_expr]))
1643+
let children = expr
1644+
.children
1645+
.iter()
1646+
.map(|child| self.create_expr(child, schema.clone()))
1647+
.collect::<Result<Vec<_>, _>>()?;
1648+
Ok(("count".to_string(), children))
16511649
}
16521650
Some(AggExprStruct::Min(expr)) => {
16531651
let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
209209
expr match {
210210
case agg: AggregateExpression =>
211211
agg.aggregateFunction match {
212-
case _: Min | _: Max | _: Count | _: Sum | _: Min | _: Max =>
212+
case _: Min | _: Max | _: Count | _: Sum =>
213213
Some(agg)
214214
case _ =>
215215
withInfo(windowExpr, "Unsupported aggregate", expr)

0 commit comments

Comments
 (0)