|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use std::any::Any; |
19 |
| -use std::sync::Arc; |
20 |
| - |
21 |
| -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; |
22 |
| -use arrow::datatypes::DataType; |
23 |
| -use datafusion_common::cast::{ |
24 |
| - as_generic_string_array, as_int64_array, as_string_view_array, |
25 |
| -}; |
26 |
| -use unicode_segmentation::UnicodeSegmentation; |
27 |
| - |
| 18 | +use crate::string::common::StringArrayType; |
28 | 19 | use crate::utils::{make_scalar_function, utf8_to_str_type};
|
| 20 | +use arrow::array::{ |
| 21 | + ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, |
| 22 | + OffsetSizeTrait, StringViewArray, |
| 23 | +}; |
| 24 | +use arrow::datatypes::DataType; |
| 25 | +use datafusion_common::cast::as_int64_array; |
| 26 | +use datafusion_common::DataFusionError; |
29 | 27 | use datafusion_common::{exec_err, Result};
|
30 | 28 | use datafusion_expr::TypeSignature::Exact;
|
31 | 29 | use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
|
| 30 | +use std::any::Any; |
| 31 | +use std::fmt::Write; |
| 32 | +use std::sync::Arc; |
| 33 | +use unicode_segmentation::UnicodeSegmentation; |
| 34 | +use DataType::{LargeUtf8, Utf8, Utf8View}; |
32 | 35 |
|
33 | 36 | #[derive(Debug)]
|
34 | 37 | pub struct RPadFunc {
|
@@ -84,170 +87,182 @@ impl ScalarUDFImpl for RPadFunc {
|
84 | 87 | }
|
85 | 88 |
|
86 | 89 | fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
|
87 |
| - match args.len() { |
88 |
| - 2 => match args[0].data_type() { |
89 |
| - DataType::Utf8 | DataType::Utf8View => { |
90 |
| - make_scalar_function(rpad::<i32, i32>, vec![])(args) |
91 |
| - } |
92 |
| - DataType::LargeUtf8 => { |
93 |
| - make_scalar_function(rpad::<i64, i64>, vec![])(args) |
94 |
| - } |
95 |
| - other => exec_err!("Unsupported data type {other:?} for function rpad"), |
96 |
| - }, |
97 |
| - 3 => match (args[0].data_type(), args[2].data_type()) { |
98 |
| - ( |
99 |
| - DataType::Utf8 | DataType::Utf8View, |
100 |
| - DataType::Utf8 | DataType::Utf8View, |
101 |
| - ) => make_scalar_function(rpad::<i32, i32>, vec![])(args), |
102 |
| - (DataType::LargeUtf8, DataType::LargeUtf8) => { |
103 |
| - make_scalar_function(rpad::<i64, i64>, vec![])(args) |
104 |
| - } |
105 |
| - (DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => { |
106 |
| - make_scalar_function(rpad::<i64, i32>, vec![])(args) |
107 |
| - } |
108 |
| - (DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => { |
109 |
| - make_scalar_function(rpad::<i32, i64>, vec![])(args) |
110 |
| - } |
111 |
| - (first_type, last_type) => { |
112 |
| - exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type) |
113 |
| - } |
114 |
| - }, |
115 |
| - number => { |
116 |
| - exec_err!("unsupported arguments number {} for rpad", number) |
| 90 | + match ( |
| 91 | + args.len(), |
| 92 | + args[0].data_type(), |
| 93 | + args.get(2).map(|arg| arg.data_type()), |
| 94 | + ) { |
| 95 | + (2, Utf8 | Utf8View, _) => { |
| 96 | + make_scalar_function(rpad::<i32, i32>, vec![])(args) |
| 97 | + } |
| 98 | + (2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args), |
| 99 | + (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => { |
| 100 | + make_scalar_function(rpad::<i32, i32>, vec![])(args) |
| 101 | + } |
| 102 | + (3, LargeUtf8, Some(LargeUtf8)) => { |
| 103 | + make_scalar_function(rpad::<i64, i64>, vec![])(args) |
| 104 | + } |
| 105 | + (3, Utf8 | Utf8View, Some(LargeUtf8)) => { |
| 106 | + make_scalar_function(rpad::<i32, i64>, vec![])(args) |
| 107 | + } |
| 108 | + (3, LargeUtf8, Some(Utf8 | Utf8View)) => { |
| 109 | + make_scalar_function(rpad::<i64, i32>, vec![])(args) |
| 110 | + } |
| 111 | + (_, _, _) => { |
| 112 | + exec_err!("Unsupported combination of data types for function rpad") |
117 | 113 | }
|
118 | 114 | }
|
119 | 115 | }
|
120 | 116 | }
|
121 | 117 |
|
122 |
| -macro_rules! process_rpad { |
123 |
| - // For the two-argument case |
124 |
| - ($string_array:expr, $length_array:expr) => {{ |
125 |
| - $string_array |
126 |
| - .iter() |
127 |
| - .zip($length_array.iter()) |
128 |
| - .map(|(string, length)| match (string, length) { |
129 |
| - (Some(string), Some(length)) => { |
130 |
| - if length > i32::MAX as i64 { |
131 |
| - return exec_err!("rpad requested length {} too large", length); |
132 |
| - } |
133 |
| - |
134 |
| - let length = if length < 0 { 0 } else { length as usize }; |
135 |
| - if length == 0 { |
136 |
| - Ok(Some("".to_string())) |
137 |
| - } else { |
138 |
| - let graphemes = string.graphemes(true).collect::<Vec<&str>>(); |
139 |
| - if length < graphemes.len() { |
140 |
| - Ok(Some(graphemes[..length].concat())) |
141 |
| - } else { |
142 |
| - let mut s = string.to_string(); |
143 |
| - s.push_str(" ".repeat(length - graphemes.len()).as_str()); |
144 |
| - Ok(Some(s)) |
145 |
| - } |
146 |
| - } |
147 |
| - } |
148 |
| - _ => Ok(None), |
149 |
| - }) |
150 |
| - .collect::<Result<GenericStringArray<StringArrayLen>>>() |
151 |
| - }}; |
152 |
| - |
153 |
| - // For the three-argument case |
154 |
| - ($string_array:expr, $length_array:expr, $fill_array:expr) => {{ |
155 |
| - $string_array |
156 |
| - .iter() |
157 |
| - .zip($length_array.iter()) |
158 |
| - .zip($fill_array.iter()) |
159 |
| - .map(|((string, length), fill)| match (string, length, fill) { |
160 |
| - (Some(string), Some(length), Some(fill)) => { |
161 |
| - if length > i32::MAX as i64 { |
162 |
| - return exec_err!("rpad requested length {} too large", length); |
163 |
| - } |
164 |
| - |
165 |
| - let length = if length < 0 { 0 } else { length as usize }; |
166 |
| - let graphemes = string.graphemes(true).collect::<Vec<&str>>(); |
167 |
| - let fill_chars = fill.chars().collect::<Vec<char>>(); |
| 118 | +pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>( |
| 119 | + args: &[ArrayRef], |
| 120 | +) -> Result<ArrayRef> { |
| 121 | + if args.len() < 2 || args.len() > 3 { |
| 122 | + return exec_err!( |
| 123 | + "rpad was called with {} arguments. It requires 2 or 3 arguments.", |
| 124 | + args.len() |
| 125 | + ); |
| 126 | + } |
168 | 127 |
|
169 |
| - if length < graphemes.len() { |
170 |
| - Ok(Some(graphemes[..length].concat())) |
171 |
| - } else if fill_chars.is_empty() { |
172 |
| - Ok(Some(string.to_string())) |
173 |
| - } else { |
174 |
| - let mut s = string.to_string(); |
175 |
| - let char_vector: Vec<char> = (0..length - graphemes.len()) |
176 |
| - .map(|l| fill_chars[l % fill_chars.len()]) |
177 |
| - .collect(); |
178 |
| - s.push_str(&char_vector.iter().collect::<String>()); |
179 |
| - Ok(Some(s)) |
180 |
| - } |
181 |
| - } |
182 |
| - _ => Ok(None), |
183 |
| - }) |
184 |
| - .collect::<Result<GenericStringArray<StringArrayLen>>>() |
185 |
| - }}; |
| 128 | + let length_array = as_int64_array(&args[1])?; |
| 129 | + match ( |
| 130 | + args.len(), |
| 131 | + args[0].data_type(), |
| 132 | + args.get(2).map(|arg| arg.data_type()), |
| 133 | + ) { |
| 134 | + (2, Utf8View, _) => { |
| 135 | + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( |
| 136 | + args[0].as_string_view(), |
| 137 | + length_array, |
| 138 | + None, |
| 139 | + ) |
| 140 | + } |
| 141 | + (3, Utf8View, Some(Utf8View)) => { |
| 142 | + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( |
| 143 | + args[0].as_string_view(), |
| 144 | + length_array, |
| 145 | + Some(args[2].as_string_view()), |
| 146 | + ) |
| 147 | + } |
| 148 | + (3, Utf8View, Some(Utf8 | LargeUtf8)) => { |
| 149 | + rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>( |
| 150 | + args[0].as_string_view(), |
| 151 | + length_array, |
| 152 | + Some(args[2].as_string::<FillArrayLen>()), |
| 153 | + ) |
| 154 | + } |
| 155 | + (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::< |
| 156 | + &GenericStringArray<StringArrayLen>, |
| 157 | + &StringViewArray, |
| 158 | + StringArrayLen, |
| 159 | + >( |
| 160 | + args[0].as_string::<StringArrayLen>(), |
| 161 | + length_array, |
| 162 | + Some(args[2].as_string_view()), |
| 163 | + ), |
| 164 | + (_, _, _) => rpad_impl::< |
| 165 | + &GenericStringArray<StringArrayLen>, |
| 166 | + &GenericStringArray<FillArrayLen>, |
| 167 | + StringArrayLen, |
| 168 | + >( |
| 169 | + args[0].as_string::<StringArrayLen>(), |
| 170 | + length_array, |
| 171 | + args.get(2).map(|arg| arg.as_string::<FillArrayLen>()), |
| 172 | + ), |
| 173 | + } |
186 | 174 | }
|
187 | 175 |
|
188 | 176 | /// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
|
189 | 177 | /// rpad('hi', 5, 'xy') = 'hixyx'
|
190 |
| -pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>( |
191 |
| - args: &[ArrayRef], |
192 |
| -) -> Result<ArrayRef> { |
193 |
| - match (args.len(), args[0].data_type()) { |
194 |
| - (2, DataType::Utf8View) => { |
195 |
| - let string_array = as_string_view_array(&args[0])?; |
196 |
| - let length_array = as_int64_array(&args[1])?; |
| 178 | +pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>( |
| 179 | + string_array: StringArrType, |
| 180 | + length_array: &Int64Array, |
| 181 | + fill_array: Option<FillArrType>, |
| 182 | +) -> Result<ArrayRef> |
| 183 | +where |
| 184 | + StringArrType: StringArrayType<'a>, |
| 185 | + FillArrType: StringArrayType<'a>, |
| 186 | + StringArrayLen: OffsetSizeTrait, |
| 187 | +{ |
| 188 | + let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new(); |
197 | 189 |
|
198 |
| - let result = process_rpad!(string_array, length_array)?; |
199 |
| - Ok(Arc::new(result) as ArrayRef) |
| 190 | + match fill_array { |
| 191 | + None => { |
| 192 | + string_array.iter().zip(length_array.iter()).try_for_each( |
| 193 | + |(string, length)| -> Result<(), DataFusionError> { |
| 194 | + match (string, length) { |
| 195 | + (Some(string), Some(length)) => { |
| 196 | + if length > i32::MAX as i64 { |
| 197 | + return exec_err!( |
| 198 | + "rpad requested length {} too large", |
| 199 | + length |
| 200 | + ); |
| 201 | + } |
| 202 | + let length = if length < 0 { 0 } else { length as usize }; |
| 203 | + if length == 0 { |
| 204 | + builder.append_value(""); |
| 205 | + } else { |
| 206 | + let graphemes = |
| 207 | + string.graphemes(true).collect::<Vec<&str>>(); |
| 208 | + if length < graphemes.len() { |
| 209 | + builder.append_value(graphemes[..length].concat()); |
| 210 | + } else { |
| 211 | + builder.write_str(string)?; |
| 212 | + builder.write_str( |
| 213 | + &" ".repeat(length - graphemes.len()), |
| 214 | + )?; |
| 215 | + builder.append_value(""); |
| 216 | + } |
| 217 | + } |
| 218 | + } |
| 219 | + _ => builder.append_null(), |
| 220 | + } |
| 221 | + Ok(()) |
| 222 | + }, |
| 223 | + )?; |
200 | 224 | }
|
201 |
| - (2, _) => { |
202 |
| - let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?; |
203 |
| - let length_array = as_int64_array(&args[1])?; |
| 225 | + Some(fill_array) => { |
| 226 | + string_array |
| 227 | + .iter() |
| 228 | + .zip(length_array.iter()) |
| 229 | + .zip(fill_array.iter()) |
| 230 | + .try_for_each( |
| 231 | + |((string, length), fill)| -> Result<(), DataFusionError> { |
| 232 | + match (string, length, fill) { |
| 233 | + (Some(string), Some(length), Some(fill)) => { |
| 234 | + if length > i32::MAX as i64 { |
| 235 | + return exec_err!( |
| 236 | + "rpad requested length {} too large", |
| 237 | + length |
| 238 | + ); |
| 239 | + } |
| 240 | + let length = if length < 0 { 0 } else { length as usize }; |
| 241 | + let graphemes = |
| 242 | + string.graphemes(true).collect::<Vec<&str>>(); |
204 | 243 |
|
205 |
| - let result = process_rpad!(string_array, length_array)?; |
206 |
| - Ok(Arc::new(result) as ArrayRef) |
207 |
| - } |
208 |
| - (3, DataType::Utf8View) => { |
209 |
| - let string_array = as_string_view_array(&args[0])?; |
210 |
| - let length_array = as_int64_array(&args[1])?; |
211 |
| - match args[2].data_type() { |
212 |
| - DataType::Utf8View => { |
213 |
| - let fill_array = as_string_view_array(&args[2])?; |
214 |
| - let result = process_rpad!(string_array, length_array, fill_array)?; |
215 |
| - Ok(Arc::new(result) as ArrayRef) |
216 |
| - } |
217 |
| - DataType::Utf8 | DataType::LargeUtf8 => { |
218 |
| - let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?; |
219 |
| - let result = process_rpad!(string_array, length_array, fill_array)?; |
220 |
| - Ok(Arc::new(result) as ArrayRef) |
221 |
| - } |
222 |
| - other_type => { |
223 |
| - exec_err!("unsupported type for rpad's third operator: {}", other_type) |
224 |
| - } |
225 |
| - } |
226 |
| - } |
227 |
| - (3, _) => { |
228 |
| - let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?; |
229 |
| - let length_array = as_int64_array(&args[1])?; |
230 |
| - match args[2].data_type() { |
231 |
| - DataType::Utf8View => { |
232 |
| - let fill_array = as_string_view_array(&args[2])?; |
233 |
| - let result = process_rpad!(string_array, length_array, fill_array)?; |
234 |
| - Ok(Arc::new(result) as ArrayRef) |
235 |
| - } |
236 |
| - DataType::Utf8 | DataType::LargeUtf8 => { |
237 |
| - let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?; |
238 |
| - let result = process_rpad!(string_array, length_array, fill_array)?; |
239 |
| - Ok(Arc::new(result) as ArrayRef) |
240 |
| - } |
241 |
| - other_type => { |
242 |
| - exec_err!("unsupported type for rpad's third operator: {}", other_type) |
243 |
| - } |
244 |
| - } |
| 244 | + if length < graphemes.len() { |
| 245 | + builder.append_value(graphemes[..length].concat()); |
| 246 | + } else if fill.is_empty() { |
| 247 | + builder.append_value(string); |
| 248 | + } else { |
| 249 | + builder.write_str(string)?; |
| 250 | + fill.chars() |
| 251 | + .cycle() |
| 252 | + .take(length - graphemes.len()) |
| 253 | + .for_each(|ch| builder.write_char(ch).unwrap()); |
| 254 | + builder.append_value(""); |
| 255 | + } |
| 256 | + } |
| 257 | + _ => builder.append_null(), |
| 258 | + } |
| 259 | + Ok(()) |
| 260 | + }, |
| 261 | + )?; |
245 | 262 | }
|
246 |
| - (other, other_type) => exec_err!( |
247 |
| - "rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}", |
248 |
| - other, other_type |
249 |
| - ), |
250 | 263 | }
|
| 264 | + |
| 265 | + Ok(Arc::new(builder.finish()) as ArrayRef) |
251 | 266 | }
|
252 | 267 |
|
253 | 268 | #[cfg(test)]
|
|
0 commit comments