1818//! [`ScalarUDFImpl`] definitions for array_sort function.
1919
2020use crate :: utils:: make_scalar_function;
21- use arrow:: array:: { new_null_array, Array , ArrayRef , ListArray , NullBufferBuilder } ;
21+ use arrow:: array:: {
22+ new_null_array, Array , ArrayRef , GenericListArray , NullBufferBuilder , OffsetSizeTrait ,
23+ } ;
2224use arrow:: buffer:: OffsetBuffer ;
2325use arrow:: compute:: SortColumn ;
24- use arrow:: datatypes:: { DataType , Field } ;
26+ use arrow:: datatypes:: { DataType , FieldRef } ;
2527use arrow:: { compute, compute:: SortOptions } ;
26- use datafusion_common:: cast:: { as_list_array, as_string_array} ;
28+ use datafusion_common:: cast:: { as_large_list_array , as_list_array, as_string_array} ;
2729use datafusion_common:: utils:: ListCoercion ;
2830use datafusion_common:: { exec_err, plan_err, Result } ;
2931use datafusion_expr:: {
@@ -137,6 +139,9 @@ impl ScalarUDFImpl for ArraySort {
137139 DataType :: List ( field) => {
138140 Ok ( DataType :: new_list ( field. data_type ( ) . clone ( ) , true ) )
139141 }
142+ DataType :: LargeList ( field) => {
143+ Ok ( DataType :: new_large_list ( field. data_type ( ) . clone ( ) , true ) )
144+ }
140145 arg_type => {
141146 plan_err ! ( "{} does not support type {arg_type}" , self . name( ) )
142147 }
@@ -165,21 +170,15 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
165170 return exec_err ! ( "array_sort expects one to three arguments" ) ;
166171 }
167172
168- if args[ 0 ] . data_type ( ) . is_null ( ) {
169- return Ok ( Arc :: clone ( & args[ 0 ] ) ) ;
170- }
171-
172- let list_array = as_list_array ( & args[ 0 ] ) ?;
173- let row_count = list_array. len ( ) ;
174- if row_count == 0 || list_array. value_type ( ) . is_null ( ) {
173+ if args[ 0 ] . is_empty ( ) || args[ 0 ] . data_type ( ) . is_null ( ) {
175174 return Ok ( Arc :: clone ( & args[ 0 ] ) ) ;
176175 }
177176
178177 if args[ 1 ..] . iter ( ) . any ( |array| array. is_null ( 0 ) ) {
179178 return Ok ( new_null_array ( args[ 0 ] . data_type ( ) , args[ 0 ] . len ( ) ) ) ;
180179 }
181180
182- let sort_option = match args. len ( ) {
181+ let sort_options = match args. len ( ) {
183182 1 => None ,
184183 2 => {
185184 let sort = as_string_array ( & args[ 1 ] ) ?. value ( 0 ) ;
@@ -196,9 +195,37 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
196195 nulls_first : order_nulls_first ( nulls_first) ?,
197196 } )
198197 }
199- _ => return exec_err ! ( "array_sort expects 1 to 3 arguments" ) ,
198+ // We guard at the top
199+ _ => unreachable ! ( ) ,
200200 } ;
201201
202+ match args[ 0 ] . data_type ( ) {
203+ DataType :: List ( field) | DataType :: LargeList ( field)
204+ if field. data_type ( ) . is_null ( ) =>
205+ {
206+ Ok ( Arc :: clone ( & args[ 0 ] ) )
207+ }
208+ DataType :: List ( field) => {
209+ let array = as_list_array ( & args[ 0 ] ) ?;
210+ array_sort_generic ( array, field, sort_options)
211+ }
212+ DataType :: LargeList ( field) => {
213+ let array = as_large_list_array ( & args[ 0 ] ) ?;
214+ array_sort_generic ( array, field, sort_options)
215+ }
216+ // Signature should prevent this arm ever occurring
217+ _ => exec_err ! ( "array_sort expects list for first argument" ) ,
218+ }
219+ }
220+
221+ /// Array_sort SQL function
222+ pub fn array_sort_generic < OffsetSize : OffsetSizeTrait > (
223+ list_array : & GenericListArray < OffsetSize > ,
224+ field : & FieldRef ,
225+ sort_options : Option < SortOptions > ,
226+ ) -> Result < ArrayRef > {
227+ let row_count = list_array. len ( ) ;
228+
202229 let mut array_lengths = vec ! [ ] ;
203230 let mut arrays = vec ! [ ] ;
204231 let mut valid = NullBufferBuilder :: new ( row_count) ;
@@ -216,14 +243,14 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
216243 DataType :: Struct ( _) => {
217244 let sort_columns: Vec < SortColumn > = vec ! [ SortColumn {
218245 values: Arc :: clone( & arr_ref) ,
219- options: sort_option ,
246+ options: sort_options ,
220247 } ] ;
221248 let indices = compute:: lexsort_to_indices ( & sort_columns, None ) ?;
222249 compute:: take ( arr_ref. as_ref ( ) , & indices, None ) ?
223250 }
224251 _ => {
225252 let arr_ref = arr_ref. as_ref ( ) ;
226- compute:: sort ( arr_ref, sort_option ) ?
253+ compute:: sort ( arr_ref, sort_options ) ?
227254 }
228255 } ;
229256 array_lengths. push ( sorted_array. len ( ) ) ;
@@ -232,8 +259,6 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
232259 }
233260 }
234261
235- // Assume all arrays have the same data type
236- let data_type = list_array. value_type ( ) ;
237262 let buffer = valid. finish ( ) ;
238263
239264 let elements = arrays
@@ -242,10 +267,10 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
242267 . collect :: < Vec < & dyn Array > > ( ) ;
243268
244269 let list_arr = if elements. is_empty ( ) {
245- ListArray :: new_null ( Arc :: new ( Field :: new_list_field ( data_type , true ) ) , row_count)
270+ GenericListArray :: < OffsetSize > :: new_null ( Arc :: clone ( field ) , row_count)
246271 } else {
247- ListArray :: new (
248- Arc :: new ( Field :: new_list_field ( data_type , true ) ) ,
272+ GenericListArray :: < OffsetSize > :: new (
273+ Arc :: clone ( field ) ,
249274 OffsetBuffer :: from_lengths ( array_lengths) ,
250275 Arc :: new ( compute:: concat ( elements. as_slice ( ) ) ?) ,
251276 buffer,
0 commit comments