Skip to content

Commit 22f9e23

Browse files
committed
fix erro on Count(Expr:Wildcard) with DataFrame API
1 parent 146a949 commit 22f9e23

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

Diff for: datafusion/core/tests/dataframe.rs

+65
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::{DataType, Field, Schema};
19+
use arrow::util::pretty::pretty_format_batches;
1920
use arrow::{
2021
array::{
2122
ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder,
@@ -35,6 +36,70 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
3536
use datafusion_expr::expr::{GroupingSet, Sort};
3637
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
3738

39+
#[tokio::test]
40+
async fn count_wildcard() -> Result<()> {
41+
let ctx = SessionContext::new();
42+
let testdata = datafusion::test_util::parquet_test_data();
43+
44+
ctx.register_parquet(
45+
"alltypes_tiny_pages",
46+
&format!("{testdata}/alltypes_tiny_pages.parquet"),
47+
ParquetReadOptions::default(),
48+
)
49+
.await?;
50+
51+
let expected = vec![
52+
"+---------------+---------------------------------------------------+",
53+
"| plan_type | plan |",
54+
"+---------------+---------------------------------------------------+",
55+
"| | |",
56+
"| | EmptyExec: produce_one_row=true |",
57+
"| | TableScan: alltypes_tiny_pages projection=[id] |",
58+
"| logical_plan | Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] |",
59+
"| physical_plan | ProjectionExec: expr=[7300 as COUNT(UInt8(1))] |",
60+
"+---------------+---------------------------------------------------+",
61+
];
62+
let sql_results = ctx
63+
.sql("select count(*) from alltypes_tiny_pages")
64+
.await?
65+
.explain(false, false)?
66+
.collect()
67+
.await?;
68+
69+
let df_results = ctx
70+
.table("alltypes_tiny_pages")
71+
.await?
72+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
73+
.explain(false, false)
74+
.unwrap()
75+
.collect()
76+
.await?;
77+
78+
//make sure sql plan same with df plan
79+
assert_eq!(
80+
pretty_format_batches(&sql_results)?.to_string(),
81+
pretty_format_batches(&df_results)?.to_string()
82+
);
83+
84+
let results = ctx
85+
.table("alltypes_tiny_pages")
86+
.await?
87+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
88+
.collect()
89+
.await?;
90+
91+
let expected = vec![
92+
"+-----------------+",
93+
"| COUNT(UInt8(1)) |",
94+
"+-----------------+",
95+
"| 7300 |",
96+
"+-----------------+",
97+
];
98+
assert_batches_sorted_eq!(expected, &results);
99+
100+
Ok(())
101+
}
102+
38103
#[tokio::test]
39104
async fn describe() -> Result<()> {
40105
let ctx = SessionContext::new();

Diff for: datafusion/expr/src/logical_plan/builder.rs

+56-2
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
//! This module provides a builder for creating LogicalPlans
1919
20+
use crate::expr::AggregateFunction;
2021
use crate::expr_rewriter::{
2122
coerce_plan_expr_for_schema, normalize_col,
2223
normalize_col_with_schemas_and_ambiguity_check, normalize_cols,
2324
rewrite_sort_cols_by_aggs,
2425
};
2526
use crate::type_coercion::binary::comparison_coercion;
2627
use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
27-
use crate::{and, binary_expr, Operator};
28+
use crate::{aggregate_function, and, binary_expr, lit, Operator};
2829
use crate::{
2930
logical_plan::{
3031
Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join,
@@ -801,6 +802,43 @@ impl LogicalPlanBuilder {
801802
) -> Result<Self> {
802803
let group_expr = normalize_cols(group_expr, &self.plan)?;
803804
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
805+
806+
//handle Count(Expr:Wildcard) with DataFrame API
807+
let aggr_expr = aggr_expr
808+
.iter()
809+
.map(|expr| {
810+
if let Expr::AggregateFunction(AggregateFunction {
811+
fun,
812+
args,
813+
distinct,
814+
filter,
815+
}) = expr
816+
{
817+
if let aggregate_function::AggregateFunction::Count = fun {
818+
if args.len() == 1 {
819+
let arg = args.get(0).unwrap().clone();
820+
match arg {
821+
Expr::Wildcard => {
822+
Expr::AggregateFunction(AggregateFunction {
823+
fun: fun.clone(),
824+
args: vec![lit(ScalarValue::UInt8(Some(1)))],
825+
distinct: *distinct,
826+
filter: filter.clone(),
827+
})
828+
}
829+
_ => expr.clone(),
830+
}
831+
} else {
832+
expr.clone()
833+
}
834+
} else {
835+
expr.clone()
836+
}
837+
} else {
838+
expr.clone()
839+
}
840+
})
841+
.collect();
804842
Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new(
805843
Arc::new(self.plan),
806844
group_expr,
@@ -1315,15 +1353,31 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result<LogicalPlan> {
13151353

13161354
#[cfg(test)]
13171355
mod tests {
1318-
use crate::{expr, expr_fn::exists};
1356+
use crate::{count, expr, expr_fn::exists};
13191357
use arrow::datatypes::{DataType, Field};
13201358
use datafusion_common::SchemaError;
13211359

13221360
use crate::logical_plan::StringifiedPlan;
13231361

13241362
use super::*;
1363+
use crate::BuiltinScalarFunction::Exp;
13251364
use crate::{col, in_subquery, lit, scalar_subquery, sum};
13261365

1366+
#[test]
1367+
fn count_wildcard() -> Result<()> {
1368+
let mut group_expr: Vec<Expr> = Vec::new();
1369+
let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![]))?
1370+
.aggregate(group_expr, vec![count(Expr::Wildcard)])?
1371+
.build()?;
1372+
1373+
let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
1374+
\n TableScan: employee_csv projection=[]";
1375+
1376+
assert_eq!(expected, format!("{plan:?}"));
1377+
1378+
Ok(())
1379+
}
1380+
13271381
#[test]
13281382
fn plan_builder_simple() -> Result<()> {
13291383
let plan =

0 commit comments

Comments
 (0)