Skip to content

Commit 00f5b7d

Browse files
authored
Support LargeList for array_sort (#17657)
1 parent 1488e10 commit 00f5b7d

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

datafusion/functions-nested/src/sort.rs

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
//! [`ScalarUDFImpl`] definitions for array_sort function.
1919
2020
use crate::utils::make_scalar_function;
21-
use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder};
21+
use arrow::array::{
22+
new_null_array, Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait,
23+
};
2224
use arrow::buffer::OffsetBuffer;
2325
use arrow::compute::SortColumn;
24-
use arrow::datatypes::{DataType, Field};
26+
use arrow::datatypes::{DataType, FieldRef};
2527
use arrow::{compute, compute::SortOptions};
26-
use datafusion_common::cast::{as_list_array, as_string_array};
28+
use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
2729
use datafusion_common::utils::ListCoercion;
2830
use datafusion_common::{exec_err, plan_err, Result};
2931
use datafusion_expr::{
@@ -137,6 +139,9 @@ impl ScalarUDFImpl for ArraySort {
137139
DataType::List(field) => {
138140
Ok(DataType::new_list(field.data_type().clone(), true))
139141
}
142+
DataType::LargeList(field) => {
143+
Ok(DataType::new_large_list(field.data_type().clone(), true))
144+
}
140145
arg_type => {
141146
plan_err!("{} does not support type {arg_type}", self.name())
142147
}
@@ -165,21 +170,15 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
165170
return exec_err!("array_sort expects one to three arguments");
166171
}
167172

168-
if args[0].data_type().is_null() {
169-
return Ok(Arc::clone(&args[0]));
170-
}
171-
172-
let list_array = as_list_array(&args[0])?;
173-
let row_count = list_array.len();
174-
if row_count == 0 || list_array.value_type().is_null() {
173+
if args[0].is_empty() || args[0].data_type().is_null() {
175174
return Ok(Arc::clone(&args[0]));
176175
}
177176

178177
if args[1..].iter().any(|array| array.is_null(0)) {
179178
return Ok(new_null_array(args[0].data_type(), args[0].len()));
180179
}
181180

182-
let sort_option = match args.len() {
181+
let sort_options = match args.len() {
183182
1 => None,
184183
2 => {
185184
let sort = as_string_array(&args[1])?.value(0);
@@ -196,9 +195,37 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
196195
nulls_first: order_nulls_first(nulls_first)?,
197196
})
198197
}
199-
_ => return exec_err!("array_sort expects 1 to 3 arguments"),
198+
// We guard at the top
199+
_ => unreachable!(),
200200
};
201201

202+
match args[0].data_type() {
203+
DataType::List(field) | DataType::LargeList(field)
204+
if field.data_type().is_null() =>
205+
{
206+
Ok(Arc::clone(&args[0]))
207+
}
208+
DataType::List(field) => {
209+
let array = as_list_array(&args[0])?;
210+
array_sort_generic(array, field, sort_options)
211+
}
212+
DataType::LargeList(field) => {
213+
let array = as_large_list_array(&args[0])?;
214+
array_sort_generic(array, field, sort_options)
215+
}
216+
// Signature should prevent this arm ever occurring
217+
_ => exec_err!("array_sort expects list for first argument"),
218+
}
219+
}
220+
221+
/// Array_sort SQL function
222+
pub fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
223+
list_array: &GenericListArray<OffsetSize>,
224+
field: &FieldRef,
225+
sort_options: Option<SortOptions>,
226+
) -> Result<ArrayRef> {
227+
let row_count = list_array.len();
228+
202229
let mut array_lengths = vec![];
203230
let mut arrays = vec![];
204231
let mut valid = NullBufferBuilder::new(row_count);
@@ -216,14 +243,14 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
216243
DataType::Struct(_) => {
217244
let sort_columns: Vec<SortColumn> = vec![SortColumn {
218245
values: Arc::clone(&arr_ref),
219-
options: sort_option,
246+
options: sort_options,
220247
}];
221248
let indices = compute::lexsort_to_indices(&sort_columns, None)?;
222249
compute::take(arr_ref.as_ref(), &indices, None)?
223250
}
224251
_ => {
225252
let arr_ref = arr_ref.as_ref();
226-
compute::sort(arr_ref, sort_option)?
253+
compute::sort(arr_ref, sort_options)?
227254
}
228255
};
229256
array_lengths.push(sorted_array.len());
@@ -232,8 +259,6 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
232259
}
233260
}
234261

235-
// Assume all arrays have the same data type
236-
let data_type = list_array.value_type();
237262
let buffer = valid.finish();
238263

239264
let elements = arrays
@@ -242,10 +267,10 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
242267
.collect::<Vec<&dyn Array>>();
243268

244269
let list_arr = if elements.is_empty() {
245-
ListArray::new_null(Arc::new(Field::new_list_field(data_type, true)), row_count)
270+
GenericListArray::<OffsetSize>::new_null(Arc::clone(field), row_count)
246271
} else {
247-
ListArray::new(
248-
Arc::new(Field::new_list_field(data_type, true)),
272+
GenericListArray::<OffsetSize>::new(
273+
Arc::clone(field),
249274
OffsetBuffer::from_lengths(array_lengths),
250275
Arc::new(compute::concat(elements.as_slice())?),
251276
buffer,

datafusion/sqllogictest/test_files/array.slt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2423,6 +2423,20 @@ select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1,
24232423
----
24242424
[NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1]
24252425

2426+
query ???
2427+
select array_sort(arrow_cast(make_array(1, 3, null, 5, NULL, -5), 'LargeList(Int64)')),
2428+
array_sort(arrow_cast(make_array(1, 3, null, 2), 'LargeList(Int64)'), 'ASC'),
2429+
array_sort(arrow_cast(make_array(1, 3, null, 2), 'LargeList(Int64)'), 'desc', 'NULLS FIRST');
2430+
----
2431+
[NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1]
2432+
2433+
query ???
2434+
select array_sort(arrow_cast(make_array(1, 3, null, 5, NULL, -5), 'FixedSizeList(6, Int64)')),
2435+
array_sort(arrow_cast(make_array(1, 3, null, 2), 'FixedSizeList(4, Int64)'), 'ASC'),
2436+
array_sort(arrow_cast(make_array(1, 3, null, 2), 'FixedSizeList(4, Int64)'), 'desc', 'NULLS FIRST');
2437+
----
2438+
[NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1]
2439+
24262440
query ?
24272441
select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values;
24282442
----

0 commit comments

Comments
 (0)