Skip to content

Commit 2f0ae2a

Browse files
committed
fix erro on Count(Expr:Wildcard) with DataFrame API
1 parent 140ce5e commit 2f0ae2a

File tree

2 files changed

+56
-45
lines changed

2 files changed

+56
-45
lines changed

datafusion/core/tests/dataframe.rs

+22-44
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use datafusion_expr::expr::{GroupingSet, Sort};
3636
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
3737

3838
#[tokio::test]
39-
async fn count_star() -> Result<()> {
39+
async fn count_wildcard() -> Result<()> {
4040
let ctx = SessionContext::new();
4141

4242
let testdata = datafusion::test_util::parquet_test_data();
@@ -48,60 +48,38 @@ async fn count_star() -> Result<()> {
4848
.await?;
4949

5050
let results = df
51+
.clone()
5152
.aggregate(vec![], vec![count(Expr::Wildcard)])?
53+
.explain(false, false)
54+
.unwrap()
5255
.collect()
5356
.await?;
5457

5558
let expected = vec![
56-
"+-----------------+",
57-
"| COUNT(UInt8(1)) |",
58-
"+-----------------+",
59-
"| 7300 |",
60-
"+-----------------+",
59+
"+---------------+---------------------------------------------------+",
60+
"| plan_type | plan |",
61+
"+---------------+---------------------------------------------------+",
62+
"| | |",
63+
"| | EmptyExec: produce_one_row=true |",
64+
"| | TableScan: ?table? projection=[id] |",
65+
"| logical_plan | Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] |",
66+
"| physical_plan | ProjectionExec: expr=[7300 as COUNT(UInt8(1))] |",
67+
"+---------------+---------------------------------------------------+",
6168
];
6269
assert_batches_sorted_eq!(expected, &results);
6370

64-
Ok(())
65-
}
66-
67-
#[tokio::test]
68-
async fn count_test() -> Result<()> {
69-
let mut ctx = SessionContext::new();
70-
let disable_aggregate_statistics = true;
71-
72-
//disable or not,this test case should be passed
73-
if disable_aggregate_statistics {
74-
let with_out_aggregate_statistics = ctx
75-
.state()
76-
.physical_optimizers()
77-
.iter()
78-
.filter(|optimizer| optimizer.as_ref().name().ne("aggregate_statistics"))
79-
.map(|optimizer| optimizer.clone())
80-
.collect();
81-
let state = ctx
82-
.state()
83-
.with_physical_optimizer_rules(with_out_aggregate_statistics);
84-
ctx = SessionContext::with_state(state);
85-
}
86-
87-
let testdata = datafusion::test_util::parquet_test_data();
88-
ctx.register_parquet(
89-
"alltypes_tiny_pages",
90-
&format!("{testdata}/alltypes_tiny_pages.parquet"),
91-
ParquetReadOptions::default(),
92-
)
93-
.await?;
94-
let results = ctx
95-
.sql("SELECT count(id), max(id), min(id) FROM alltypes_tiny_pages")
96-
.await?
71+
let results = df
72+
.clone()
73+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
9774
.collect()
9875
.await?;
76+
9977
let expected = vec![
100-
"+-------------------------------+-----------------------------+-----------------------------+",
101-
"| COUNT(alltypes_tiny_pages.id) | MAX(alltypes_tiny_pages.id) | MIN(alltypes_tiny_pages.id) |",
102-
"+-------------------------------+-----------------------------+-----------------------------+",
103-
"| 7300 | 7299 | 0 |",
104-
"+-------------------------------+-----------------------------+-----------------------------+",
78+
"+-----------------+",
79+
"| COUNT(UInt8(1)) |",
80+
"+-----------------+",
81+
"| 7300 |",
82+
"+-----------------+",
10583
];
10684
assert_batches_sorted_eq!(expected, &results);
10785

datafusion/expr/src/logical_plan/builder.rs

+34-1
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
//! This module provides a builder for creating LogicalPlans
1919
20+
use crate::expr::AggregateFunction;
2021
use crate::expr_rewriter::{
2122
coerce_plan_expr_for_schema, normalize_col,
2223
normalize_col_with_schemas_and_ambiguity_check, normalize_cols,
2324
rewrite_sort_cols_by_aggs,
2425
};
2526
use crate::type_coercion::binary::comparison_coercion;
2627
use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
27-
use crate::{and, binary_expr, Operator};
28+
use crate::{aggregate_function, and, binary_expr, lit, Operator};
2829
use crate::{
2930
logical_plan::{
3031
Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join,
@@ -801,6 +802,38 @@ impl LogicalPlanBuilder {
801802
) -> Result<Self> {
802803
let group_expr = normalize_cols(group_expr, &self.plan)?;
803804
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
805+
806+
let aggr_expr = aggr_expr
807+
.iter()
808+
.map(|expr| {
809+
if let Expr::AggregateFunction(AggregateFunction {
810+
fun,
811+
args,
812+
distinct,
813+
filter,
814+
}) = expr
815+
{
816+
if let aggregate_function::AggregateFunction::Count = fun {
817+
let w = args.get(0).unwrap();
818+
match w {
819+
Expr::Wildcard => {
820+
Expr::AggregateFunction(AggregateFunction {
821+
fun: fun.clone(),
822+
args: vec![lit(ScalarValue::UInt8(Some(1)))],
823+
distinct: distinct.clone(),
824+
filter: filter.clone(),
825+
})
826+
}
827+
_ => expr.clone(),
828+
}
829+
} else {
830+
expr.clone()
831+
}
832+
} else {
833+
expr.clone()
834+
}
835+
})
836+
.collect();
804837
Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new(
805838
Arc::new(self.plan),
806839
group_expr,

0 commit comments

Comments
 (0)