diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 5a29cf962817..979be16c4a21 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -27,13 +27,16 @@ use datafusion_common::cast::as_generic_list_array; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::utils::take_function_args; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::expr::{InList, ScalarFunction}; +use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; +use crate::make_array::make_array_udf; use crate::utils::make_scalar_function; use std::any::Any; @@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas { Ok(DataType::Boolean) } + fn simplify( + &self, + mut args: Vec<Expr>, + _info: &dyn datafusion_expr::simplify::SimplifyInfo, + ) -> Result<ExprSimplifyResult> { + let [haystack, needle] = take_function_args(self.name(), &mut args)?; + + // if the haystack is a constant list, we can use an inlist expression which is more + // efficient because the haystack is not varying per-row + if let Expr::Literal(ScalarValue::List(array)) = haystack { + // TODO: support LargeList + // (not supported by `convert_array_to_scalar_vec`) + // (FixedSizeList not supported either, but seems to have worked fine when attempting to + // build a reproducer) + + assert_eq!(array.len(), 1); // guarantee of ScalarValue + if let Ok(scalar_values) = + ScalarValue::convert_array_to_scalar_vec(array.as_ref()) + { + assert_eq!(scalar_values.len(), 1); + let list = scalar_values + .into_iter() + .flatten() + .map(Expr::Literal) + .collect(); + + return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { + expr: Box::new(std::mem::take(needle)), + list, + negated: false, + }))); + } + } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack { + // make_array has a static set of arguments, so we can pull the arguments out from it + if func == &make_array_udf() { + return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { + expr: Box::new(std::mem::take(needle)), + list: std::mem::take(args), + negated: false, + }))); + } + } + + Ok(ExprSimplifyResult::Original(args)) + } + fn invoke_with_args( &self, args: datafusion_expr::ScalarFunctionArgs, @@ -535,3 +584,86 @@ fn general_array_has_all_and_any_kernel( }), } } + +#[cfg(test)] +mod tests { + use arrow::array::create_array; + use datafusion_common::utils::SingleRowListArrayBuilder; + use datafusion_expr::{ + col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, Expr, + ScalarUDFImpl, + }; + + use crate::expr_fn::make_array; + + use super::ArrayHas; + + #[test] + fn test_simplify_array_has_to_in_list() { + let haystack = lit(SingleRowListArrayBuilder::new(create_array!( + Int32, + [1, 2, 3] + )) + .build_list_scalar()); + let needle = col("c"); + + let props = ExecutionProps::new(); + let context = datafusion_expr::simplify::SimplifyContext::new(&props); + + let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) = + ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) + else { + panic!("Expected simplified expression"); + }; + + assert_eq!( + in_list, + datafusion_expr::expr::InList { + expr: Box::new(needle), + list: vec![lit(1), lit(2), lit(3)], + negated: false, + } + ); + } + + #[test] + fn test_simplify_array_has_with_make_array_to_in_list() { + let haystack = make_array(vec![lit(1), lit(2), lit(3)]); + let needle = col("c"); + + let props = ExecutionProps::new(); + let context = datafusion_expr::simplify::SimplifyContext::new(&props); + + let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) = + ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) + else { + panic!("Expected simplified expression"); + }; + + assert_eq!( + in_list, + datafusion_expr::expr::InList { + expr: Box::new(needle), + list: vec![lit(1), lit(2), lit(3)], + negated: false, + } + ); + } + + #[test] + fn test_array_has_complex_list_not_simplified() { + let haystack = col("c1"); + let needle = col("c2"); + + let props = ExecutionProps::new(); + let context = datafusion_expr::simplify::SimplifyContext::new(&props); + + let Ok(ExprSimplifyResult::Original(args)) = + ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) + else { + panic!("Expected simplified expression"); + }; + + assert_eq!(args, vec![col("c1"), col("c2")],); + } +} diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 3b7f12960681..929a0408d62c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5960,6 +5960,188 @@ true false true false false false true true false false true false true #---- #true false true false false false true true false false true false true +# rewrite various array_has operations to InList where the haystack is a literal list +# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'); +---- +1 + +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']); +---- +1 + +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle); +---- +1 + +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + +# FIXME: due to rewrite below not working, this is _extremely_ slow to evaluate +# query I +# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +# select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); +# ---- +# 1 + +# FIXME: array_has with large list haystack not currently rewritten to InList +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32))) +07)------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8)), 1, 32)) +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle); +---- +1 + +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + +query I +with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has([needle], needle); +---- +100000 + +# TODO: this should probably be possible to completely remove the filter as always true? +query TT +explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has([needle], needle); +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: test +04)------SubqueryAlias: t +05)--------Projection: +06)----------Filter: __common_expr_3 = __common_expr_3 +07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3 +08)--------------TableScan: tmp_table projection=[value] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------ProjectionExec: expr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0 +08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1, 32) as __common_expr_3] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] + # any operator query ? select column3 from arrays where 'L'=any(column3);