feat: Support count AggregateUDF for window function
huaxingao committed Jul 29, 2024
1 parent b04baa5 commit 04ab458
Showing 3 changed files with 27 additions and 5 deletions.
28 changes: 25 additions & 3 deletions native/core/src/execution/datafusion/planner.rs
@@ -18,11 +18,13 @@
//! Converts Spark physical plan to DataFusion physical plan
use std::{collections::HashMap, sync::Arc};
use std::str::FromStr;

use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::logical_expr::BuiltInWindowFunction;
use datafusion::physical_plan::windows::BoundedWindowAggExec;
use datafusion::physical_plan::InputOrderMode;
use datafusion::{
@@ -55,8 +57,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
JoinType as DFJoinType, ScalarValue,
};
use datafusion_expr::expr::find_df_window_func;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_expr::{aggregate_function, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition};
use datafusion_physical_expr::window::WindowExpr;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use itertools::Itertools;
@@ -1483,7 +1484,7 @@ impl PhysicalPlanner {
));
}

let window_func = match find_df_window_func(&window_func_name) {
let window_func = match Self::find_df_window_func(&window_func_name) {
Some(f) => f,
_ => {
return Err(ExecutionError::GeneralError(format!(
@@ -1599,6 +1600,27 @@ impl PhysicalPlanner {
}
}

    /// Find DataFusion's built-in window function by name.
    fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
        let name = name.to_lowercase();
        if let Ok(built_in_function) =
            BuiltInWindowFunction::from_str(name.as_str())
        {
            Some(WindowFunctionDefinition::BuiltInWindowFunction(
                built_in_function,
            ))
        } else if let Ok(aggregate) =
            aggregate_function::AggregateFunction::from_str(name.as_str())
        {
            Some(WindowFunctionDefinition::AggregateFunction(aggregate))
        } else {
            match name.as_str() {
                "count" => Some(WindowFunctionDefinition::AggregateUDF(count_udaf())),
                _ => None,
            }
        }
    }

/// Create a DataFusion physical partitioning from Spark physical partitioning
fn create_partitioning(
&self,
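For context, the new find_df_window_func helper above resolves a window-function name in three steps: DataFusion's built-in window functions first, then its remaining built-in aggregate functions, and finally a fallback to aggregate UDFs such as count, which DataFusion 40 no longer exposes as a built-in function. The following is a minimal, self-contained sketch of that resolution order only; the enum and the name lists are hypothetical stand-ins for WindowFunctionDefinition, BuiltInWindowFunction::from_str, AggregateFunction::from_str, and count_udaf, not the actual DataFusion API.

// Minimal sketch of the resolution order, with hypothetical simplified types.
#[derive(Debug, PartialEq)]
enum ResolvedWindowFunc {
    BuiltInWindow(String),    // e.g. row_number, rank
    BuiltInAggregate(String), // e.g. min, max (still built-in in DataFusion 40)
    AggregateUdf(String),     // e.g. count, now routed through a UDAF
}

fn resolve_window_func(name: &str) -> Option<ResolvedWindowFunc> {
    let name = name.to_lowercase();
    let built_in_window = ["row_number", "rank", "dense_rank", "lag", "lead"];
    let built_in_aggregate = ["min", "max"];

    if built_in_window.contains(&name.as_str()) {
        Some(ResolvedWindowFunc::BuiltInWindow(name))
    } else if built_in_aggregate.contains(&name.as_str()) {
        Some(ResolvedWindowFunc::BuiltInAggregate(name))
    } else {
        // Fallback for aggregates that are only available as UDAFs.
        match name.as_str() {
            "count" => Some(ResolvedWindowFunc::AggregateUdf(name)),
            _ => None,
        }
    }
}

fn main() {
    // Names are matched case-insensitively, as in the helper above.
    assert_eq!(
        resolve_window_func("COUNT"),
        Some(ResolvedWindowFunc::AggregateUdf("count".to_string()))
    );
    assert_eq!(resolve_window_func("median"), None);
}

In the actual diff, the fallback returns WindowFunctionDefinition::AggregateUDF(count_udaf()), so the planner can treat count like any other window aggregate.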
@@ -211,7 +211,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
// TODO add support for Count (this was removed when upgrading
// to DataFusion 40 because it is no longer a built-in window function)
// https://github.com/apache/datafusion-comet/issues/645
case _: Min | _: Max =>
case _: Min | _: Max | _: Count =>
Some(agg)
case _ =>
withInfo(windowExpr, "Unsupported aggregate", expr)
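With Count now accepted by QueryPlanSerde, the window expression is serialized to the native side, where the planner branches on the resolved definition when building the physical window expression. Below is a hypothetical, simplified sketch of that dispatch; the variant names mirror WindowFunctionDefinition's three cases, but the types and builder are illustrative stand-ins, not Comet's or DataFusion's actual code.

// Hypothetical sketch: dispatching on the three kinds of resolved window functions.
enum WindowFuncKind {
    BuiltInWindow(&'static str),    // e.g. row_number
    BuiltInAggregate(&'static str), // e.g. min, max
    AggregateUdf(&'static str),     // e.g. count via count_udaf()
}

fn describe_window_expr(func: &WindowFuncKind) -> String {
    match func {
        WindowFuncKind::BuiltInWindow(name) => {
            format!("built-in window expression: {name}")
        }
        WindowFuncKind::BuiltInAggregate(name) => {
            format!("aggregate-backed window expression: {name}")
        }
        // The case enabled by this commit: UDAF-backed aggregates are
        // wrapped as window aggregates like the built-in ones.
        WindowFuncKind::AggregateUdf(name) => {
            format!("UDAF-backed window expression: {name}")
        }
    }
}

fn main() {
    println!("{}", describe_window_expr(&WindowFuncKind::AggregateUdf("count")));
}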
@@ -1466,7 +1466,7 @@ class CometExecSuite extends CometTestBase {
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls
val aggregateFunctions =
List("MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates
List("COUNT(_1)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates

aggregateFunctions.foreach { function =>
val queries = Seq(
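The test change above adds COUNT(_1) to the window aggregates exercised by the suite; the concrete queries live in the collapsed Seq and are not shown here. As a toy, self-contained illustration of the semantics being tested, the sketch below computes a running COUNT over a cumulative frame (UNBOUNDED PRECEDING to CURRENT ROW), skipping nulls as COUNT(column) does; the frame shape is an assumption, since the hidden queries may use a different one.

// Toy illustration: COUNT over a cumulative window frame
// (UNBOUNDED PRECEDING to CURRENT ROW), counting non-null values.
fn running_count(values: &[Option<i32>]) -> Vec<i64> {
    let mut count = 0i64;
    values
        .iter()
        .map(|v| {
            if v.is_some() {
                count += 1; // COUNT(column) skips nulls
            }
            count
        })
        .collect()
}

fn main() {
    let col = [Some(10), None, Some(8), Some(7)];
    // Non-null running count per row: 1, 1, 2, 3
    assert_eq!(running_count(&col), vec![1, 1, 2, 3]);
}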
