Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify array_has UDF to InList expr when haystack is constant #15354

Merged
merged 4 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion datafusion/functions-nested/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ use datafusion_common::cast::as_generic_list_array;
use datafusion_common::utils::string_utils::string_array_to_vec;
use datafusion_common::utils::take_function_args;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::expr::{InList, ScalarFunction};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;

use crate::make_array::make_array_udf;
use crate::utils::make_scalar_function;

use std::any::Any;
Expand Down Expand Up @@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
Ok(DataType::Boolean)
}

fn simplify(
&self,
mut args: Vec<Expr>,
_info: &dyn datafusion_expr::simplify::SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let [haystack, needle] = take_function_args(self.name(), &mut args)?;

// if the haystack is a constant list, we can use an inlist expression which is more
// efficient because the haystack is not varying per-row
if let Expr::Literal(ScalarValue::List(array)) = haystack {
// TODO: support LargeList
// (not supported by `convert_array_to_scalar_vec`)
// (FixedSizeList not supported either, but seems to have worked fine when attempting to
// build a reproducer)

assert_eq!(array.len(), 1); // guarantee of ScalarValue
if let Ok(scalar_values) =
ScalarValue::convert_array_to_scalar_vec(array.as_ref())
{
assert_eq!(scalar_values.len(), 1);
let list = scalar_values
.into_iter()
.flatten()
.map(Expr::Literal)
.collect();

return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
expr: Box::new(std::mem::take(needle)),
list,
negated: false,
})));
}
} else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack {
// make_array has a static set of arguments, so we can pull the arguments out from it
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect that during constant evaluation make_array would be turned into a literal, so this case would be unnecessary

However, you wouldn't observe that simplification happening in unit tests (only in the slt tests when everything was put together)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested removing this case and the slt tests failed like this

Completed 113 test files in 3 seconds                                                                                                                                 External error: query result mismatch:
[SQL] explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has([needle], needle);
[Diff] (-expected|+actual)
    logical_plan
    01)Projection: count(Int64(1)) AS count(*)
    02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
    03)----SubqueryAlias: test
    04)------SubqueryAlias: t
-   05)--------Projection:
-   06)----------Filter: __common_expr_3 = __common_expr_3
+   05)--------Projection:
+   06)----------Filter: array_has(make_array(__common_expr_3), __common_expr_3)
    07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3
    08)--------------TableScan: tmp_table projection=[value]
    physical_plan
    01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
    02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
    03)----CoalescePartitionsExec
    04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
    05)--------ProjectionExec: expr=[]
    06)----------CoalesceBatchesExec: target_batch_size=8192
-   07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0
+   07)------------FilterExec: array_has(make_array(__common_expr_3@0), __common_expr_3@0)
    08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1, 32) as __common_expr_3]
    09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
    10)------------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
at test_files/array.slt:6120

I found that unexpected but don't have time to look into it more now

if func == &make_array_udf() {
return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
expr: Box::new(std::mem::take(needle)),
list: std::mem::take(args),
negated: false,
})));
}
}

Ok(ExprSimplifyResult::Original(args))
}

fn invoke_with_args(
&self,
args: datafusion_expr::ScalarFunctionArgs,
Expand Down Expand Up @@ -535,3 +584,86 @@ fn general_array_has_all_and_any_kernel(
}),
}
}

#[cfg(test)]
mod tests {
    //! Unit tests for `ArrayHas::simplify`.
    //!
    //! Each test calls `simplify` directly with a `[haystack, needle]` argument
    //! vector and checks which `ExprSimplifyResult` variant comes back.

    use arrow::array::create_array;
    use datafusion_common::utils::SingleRowListArrayBuilder;
    use datafusion_expr::{
        col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, Expr,
        ScalarUDFImpl,
    };

    use crate::expr_fn::make_array;

    use super::ArrayHas;

    /// A literal list haystack is rewritten to `needle IN (1, 2, 3)`.
    #[test]
    fn test_simplify_array_has_to_in_list() {
        // Haystack is a single-row list scalar holding the Int32 values [1, 2, 3].
        let haystack = lit(SingleRowListArrayBuilder::new(create_array!(
            Int32,
            [1, 2, 3]
        ))
        .build_list_scalar());
        let needle = col("c");

        let props = ExecutionProps::new();
        let context = datafusion_expr::simplify::SimplifyContext::new(&props);

        let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
            ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
        else {
            panic!("Expected simplified expression");
        };

        assert_eq!(
            in_list,
            datafusion_expr::expr::InList {
                expr: Box::new(needle),
                list: vec![lit(1), lit(2), lit(3)],
                negated: false,
            }
        );
    }

    /// A `make_array(1, 2, 3)` haystack is rewritten to `needle IN (1, 2, 3)`,
    /// with the `make_array` arguments becoming the IN-list elements.
    #[test]
    fn test_simplify_array_has_with_make_array_to_in_list() {
        let haystack = make_array(vec![lit(1), lit(2), lit(3)]);
        let needle = col("c");

        let props = ExecutionProps::new();
        let context = datafusion_expr::simplify::SimplifyContext::new(&props);

        let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
            ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
        else {
            panic!("Expected simplified expression");
        };

        assert_eq!(
            in_list,
            datafusion_expr::expr::InList {
                expr: Box::new(needle),
                list: vec![lit(1), lit(2), lit(3)],
                negated: false,
            }
        );
    }

    /// A column haystack is neither a literal list nor a `make_array` call, so
    /// the arguments must be handed back unchanged.
    #[test]
    fn test_array_has_complex_list_not_simplified() {
        let haystack = col("c1");
        let needle = col("c2");

        let props = ExecutionProps::new();
        let context = datafusion_expr::simplify::SimplifyContext::new(&props);

        // `needle` is not used again, so it is moved rather than cloned.
        let Ok(ExprSimplifyResult::Original(args)) =
            ArrayHas::new().simplify(vec![haystack, needle], &context)
        else {
            panic!("Expected original (unsimplified) arguments");
        };

        assert_eq!(args, vec![col("c1"), col("c2")]);
    }
}
182 changes: 182 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5960,6 +5960,188 @@ true false true false false false true true false false true false true
#----
#true false true false false false true true false false true false true

# rewrite various array_has operations to InList where the haystack is a literal list
# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


query I
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c');
----
1

query TT
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c');
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
----
1

query TT
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is cool to see

07)------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle);
----
1

query TT
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle);
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

# FIXME: due to rewrite below not working, this is _extremely_ slow to evaluate
# query I
# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
# select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
# ----
# 1

# FIXME: array_has with large list haystack not currently rewritten to InList
query TT
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)))
07)------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8)), 1, 32))
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle);
----
1

query TT
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle);
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has([needle], needle);
----
100000

# TODO: it should probably be possible to remove this filter completely, since it is always true?
query TT
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has([needle], needle);
----
logical_plan
01)Projection: count(Int64(1)) AS count(*)
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: __common_expr_3 = __common_expr_3
07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3
08)--------------TableScan: tmp_table projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0
08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1, 32) as __common_expr_3]
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
10)------------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
Comment on lines +6119 to +6143
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably #15387


# any operator
query ?
select column3 from arrays where 'L'=any(column3);
Expand Down