Skip to content

Commit c2d839f

Browse files
authored
chore: refactor array fn signatures & add more slt tests (#17672)
1 parent d96fbde commit c2d839f

File tree

8 files changed

+216
-37
lines changed

8 files changed

+216
-37
lines changed

datafusion/functions-nested/src/array_has.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ impl Default for ArrayHasAll {
497497
impl ArrayHasAll {
498498
pub fn new() -> Self {
499499
Self {
500-
signature: Signature::any(2, Volatility::Immutable),
500+
signature: Signature::arrays(2, None, Volatility::Immutable),
501501
aliases: vec![String::from("list_has_all")],
502502
}
503503
}
@@ -571,7 +571,7 @@ impl Default for ArrayHasAny {
571571
impl ArrayHasAny {
572572
pub fn new() -> Self {
573573
Self {
574-
signature: Signature::any(2, Volatility::Immutable),
574+
signature: Signature::arrays(2, None, Volatility::Immutable),
575575
aliases: vec![String::from("list_has_any"), String::from("arrays_overlap")],
576576
}
577577
}

datafusion/functions-nested/src/cardinality.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ impl Cardinality {
5858
],
5959
Volatility::Immutable,
6060
),
61-
aliases: vec![],
6261
}
6362
}
6463
}
@@ -83,7 +82,6 @@ impl Cardinality {
8382
#[derive(Debug, PartialEq, Eq, Hash)]
8483
pub struct Cardinality {
8584
signature: Signature,
86-
aliases: Vec<String>,
8785
}
8886

8987
impl Default for Cardinality {
@@ -114,10 +112,6 @@ impl ScalarUDFImpl for Cardinality {
114112
make_scalar_function(cardinality_inner)(&args.args)
115113
}
116114

117-
fn aliases(&self) -> &[String] {
118-
&self.aliases
119-
}
120-
121115
fn documentation(&self) -> Option<&Documentation> {
122116
self.doc()
123117
}

datafusion/functions-nested/src/flatten.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ use arrow::datatypes::{
2525
DataType::{FixedSizeList, LargeList, List, Null},
2626
};
2727
use datafusion_common::cast::{as_large_list_array, as_list_array};
28-
use datafusion_common::utils::ListCoercion;
2928
use datafusion_common::{exec_err, utils::take_function_args, Result};
3029
use datafusion_expr::{
31-
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32-
ScalarUDFImpl, Signature, TypeSignature, Volatility,
30+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3331
};
3432
use datafusion_macros::user_doc;
3533
use std::any::Any;
@@ -75,15 +73,7 @@ impl Default for Flatten {
7573
impl Flatten {
7674
pub fn new() -> Self {
7775
Self {
78-
signature: Signature {
79-
type_signature: TypeSignature::ArraySignature(
80-
ArrayFunctionSignature::Array {
81-
arguments: vec![ArrayFunctionArgument::Array],
82-
array_coercion: Some(ListCoercion::FixedSizedListToList),
83-
},
84-
),
85-
volatility: Volatility::Immutable,
86-
},
76+
signature: Signature::array(Volatility::Immutable),
8777
aliases: vec![],
8878
}
8979
}
@@ -104,7 +94,7 @@ impl ScalarUDFImpl for Flatten {
10494

10595
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
10696
let data_type = match &arg_types[0] {
107-
List(field) | FixedSizeList(field, _) => match field.data_type() {
97+
List(field) => match field.data_type() {
10898
List(field) | FixedSizeList(field, _) => List(Arc::clone(field)),
10999
_ => arg_types[0].clone(),
110100
},

datafusion/functions-nested/src/length.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ use arrow::datatypes::{
2929
use datafusion_common::cast::{
3030
as_fixed_size_list_array, as_generic_list_array, as_int64_array,
3131
};
32-
use datafusion_common::{exec_err, plan_err, Result};
32+
use datafusion_common::{exec_err, Result};
3333
use datafusion_expr::{
34-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
34+
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
35+
ScalarUDFImpl, Signature, TypeSignature, Volatility,
3536
};
3637
use datafusion_functions::downcast_arg;
3738
use datafusion_macros::user_doc;
@@ -79,7 +80,22 @@ impl Default for ArrayLength {
7980
impl ArrayLength {
8081
pub fn new() -> Self {
8182
Self {
82-
signature: Signature::variadic_any(Volatility::Immutable),
83+
signature: Signature::one_of(
84+
vec![
85+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
86+
arguments: vec![ArrayFunctionArgument::Array],
87+
array_coercion: None,
88+
}),
89+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
90+
arguments: vec![
91+
ArrayFunctionArgument::Array,
92+
ArrayFunctionArgument::Index,
93+
],
94+
array_coercion: None,
95+
}),
96+
],
97+
Volatility::Immutable,
98+
),
8399
aliases: vec![String::from("list_length")],
84100
}
85101
}
@@ -97,13 +113,8 @@ impl ScalarUDFImpl for ArrayLength {
97113
&self.signature
98114
}
99115

100-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101-
Ok(match arg_types[0] {
102-
List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64,
103-
_ => {
104-
return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList.");
105-
}
106-
})
116+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
117+
Ok(UInt64)
107118
}
108119

109120
fn invoke_with_args(

datafusion/functions-nested/src/remove.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ use arrow::array::{
2626
use arrow::buffer::OffsetBuffer;
2727
use arrow::datatypes::{DataType, Field};
2828
use datafusion_common::cast::as_int64_array;
29+
use datafusion_common::utils::ListCoercion;
2930
use datafusion_common::{exec_err, utils::take_function_args, Result};
3031
use datafusion_expr::{
31-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32+
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33+
ScalarUDFImpl, Signature, TypeSignature, Volatility,
3234
};
3335
use datafusion_macros::user_doc;
3436
use std::any::Any;
@@ -156,7 +158,17 @@ pub(super) struct ArrayRemoveN {
156158
impl ArrayRemoveN {
157159
pub fn new() -> Self {
158160
Self {
159-
signature: Signature::any(3, Volatility::Immutable),
161+
signature: Signature::new(
162+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
163+
arguments: vec![
164+
ArrayFunctionArgument::Array,
165+
ArrayFunctionArgument::Element,
166+
ArrayFunctionArgument::Index,
167+
],
168+
array_coercion: Some(ListCoercion::FixedSizedListToList),
169+
}),
170+
Volatility::Immutable,
171+
),
160172
aliases: vec!["list_remove_n".to_string()],
161173
}
162174
}

datafusion/functions-nested/src/resize.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use arrow::buffer::OffsetBuffer;
2626
use arrow::datatypes::DataType;
2727
use arrow::datatypes::{ArrowNativeType, Field};
2828
use arrow::datatypes::{
29-
DataType::{FixedSizeList, LargeList, List},
29+
DataType::{LargeList, List},
3030
FieldRef,
3131
};
3232
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
@@ -125,7 +125,7 @@ impl ScalarUDFImpl for ArrayResize {
125125

126126
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
127127
match &arg_types[0] {
128-
List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
128+
List(field) => Ok(List(Arc::clone(field))),
129129
LargeList(field) => Ok(LargeList(Arc::clone(field))),
130130
DataType::Null => {
131131
Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true))))

datafusion/functions-nested/src/reverse.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ impl Default for ArrayReverse {
7676
impl ArrayReverse {
7777
pub fn new() -> Self {
7878
Self {
79-
signature: Signature::any(1, Volatility::Immutable),
79+
signature: Signature::array(Volatility::Immutable),
8080
aliases: vec!["list_reverse".to_string()],
8181
}
8282
}

0 commit comments

Comments
 (0)