|
17 | 17 |
|
18 | 18 | //! This module provides a builder for creating LogicalPlans
|
19 | 19 |
|
| 20 | +use crate::expr::AggregateFunction; |
20 | 21 | use crate::expr_rewriter::{
|
21 | 22 | coerce_plan_expr_for_schema, normalize_col,
|
22 | 23 | normalize_col_with_schemas_and_ambiguity_check, normalize_cols,
|
23 | 24 | rewrite_sort_cols_by_aggs,
|
24 | 25 | };
|
25 | 26 | use crate::type_coercion::binary::comparison_coercion;
|
26 | 27 | use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
|
27 |
| -use crate::{and, binary_expr, Operator}; |
| 28 | +use crate::{aggregate_function, and, binary_expr, lit, Operator}; |
28 | 29 | use crate::{
|
29 | 30 | logical_plan::{
|
30 | 31 | Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join,
|
@@ -801,6 +802,43 @@ impl LogicalPlanBuilder {
|
801 | 802 | ) -> Result<Self> {
|
802 | 803 | let group_expr = normalize_cols(group_expr, &self.plan)?;
|
803 | 804 | let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
|
| 805 | + |
| 806 | + //handle Count(Expr:Wildcard) with DataFrame API |
| 807 | + let aggr_expr = aggr_expr |
| 808 | + .iter() |
| 809 | + .map(|expr| { |
| 810 | + if let Expr::AggregateFunction(AggregateFunction { |
| 811 | + fun, |
| 812 | + args, |
| 813 | + distinct, |
| 814 | + filter, |
| 815 | + }) = expr |
| 816 | + { |
| 817 | + if let aggregate_function::AggregateFunction::Count = fun { |
| 818 | + if args.len() == 1 { |
| 819 | + let arg = args.get(0).unwrap().clone(); |
| 820 | + match arg { |
| 821 | + Expr::Wildcard => { |
| 822 | + Expr::AggregateFunction(AggregateFunction { |
| 823 | + fun: fun.clone(), |
| 824 | + args: vec![lit(ScalarValue::UInt8(Some(1)))], |
| 825 | + distinct: *distinct, |
| 826 | + filter: filter.clone(), |
| 827 | + }) |
| 828 | + } |
| 829 | + _ => expr.clone(), |
| 830 | + } |
| 831 | + } else { |
| 832 | + expr.clone() |
| 833 | + } |
| 834 | + } else { |
| 835 | + expr.clone() |
| 836 | + } |
| 837 | + } else { |
| 838 | + expr.clone() |
| 839 | + } |
| 840 | + }) |
| 841 | + .collect(); |
804 | 842 | Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new(
|
805 | 843 | Arc::new(self.plan),
|
806 | 844 | group_expr,
|
@@ -1315,15 +1353,31 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result<LogicalPlan> {
|
1315 | 1353 |
|
1316 | 1354 | #[cfg(test)]
|
1317 | 1355 | mod tests {
|
1318 |
| - use crate::{expr, expr_fn::exists}; |
| 1356 | + use crate::{count, expr, expr_fn::exists}; |
1319 | 1357 | use arrow::datatypes::{DataType, Field};
|
1320 | 1358 | use datafusion_common::SchemaError;
|
1321 | 1359 |
|
1322 | 1360 | use crate::logical_plan::StringifiedPlan;
|
1323 | 1361 |
|
1324 | 1362 | use super::*;
|
| 1363 | + use crate::BuiltinScalarFunction::Exp; |
1325 | 1364 | use crate::{col, in_subquery, lit, scalar_subquery, sum};
|
1326 | 1365 |
|
| 1366 | + #[test] |
| 1367 | + fn count_wildcard() -> Result<()> { |
| 1368 | + let mut group_expr: Vec<Expr> = Vec::new(); |
| 1369 | + let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![]))? |
| 1370 | + .aggregate(group_expr, vec![count(Expr::Wildcard)])? |
| 1371 | + .build()?; |
| 1372 | + |
| 1373 | + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ |
| 1374 | + \n TableScan: employee_csv projection=[]"; |
| 1375 | + |
| 1376 | + assert_eq!(expected, format!("{plan:?}")); |
| 1377 | + |
| 1378 | + Ok(()) |
| 1379 | + } |
| 1380 | + |
1327 | 1381 | #[test]
|
1328 | 1382 | fn plan_builder_simple() -> Result<()> {
|
1329 | 1383 | let plan =
|
|
0 commit comments