Skip to content

Commit 7eeac2f

Browse files
authored
Improve rpad udf by using a GenericStringBuilder (#12070)
* Improve rpad udf by using a GenericStringBuilder * fix format * refine code
1 parent 78f58c8 commit 7eeac2f

File tree

2 files changed

+180
-164
lines changed

2 files changed

+180
-164
lines changed

datafusion/functions/benches/pad.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,12 @@ fn criterion_benchmark(c: &mut Criterion) {
127127
group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| {
128128
b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap()))
129129
});
130-
//
131-
// let args = create_args::<i32>(size, 32, true);
132-
// group.bench_function(BenchmarkId::new("stringview type", size), |b| {
133-
// b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap()))
134-
// });
130+
131+
// rpad for stringview type
132+
let args = create_args::<i32>(size, 32, true);
133+
group.bench_function(BenchmarkId::new("stringview type", size), |b| {
134+
b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap()))
135+
});
135136

136137
group.finish();
137138
}

datafusion/functions/src/unicode/rpad.rs

+174-159
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,23 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::any::Any;
19-
use std::sync::Arc;
20-
21-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
22-
use arrow::datatypes::DataType;
23-
use datafusion_common::cast::{
24-
as_generic_string_array, as_int64_array, as_string_view_array,
25-
};
26-
use unicode_segmentation::UnicodeSegmentation;
27-
18+
use crate::string::common::StringArrayType;
2819
use crate::utils::{make_scalar_function, utf8_to_str_type};
20+
use arrow::array::{
21+
ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
22+
OffsetSizeTrait, StringViewArray,
23+
};
24+
use arrow::datatypes::DataType;
25+
use datafusion_common::cast::as_int64_array;
26+
use datafusion_common::DataFusionError;
2927
use datafusion_common::{exec_err, Result};
3028
use datafusion_expr::TypeSignature::Exact;
3129
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
30+
use std::any::Any;
31+
use std::fmt::Write;
32+
use std::sync::Arc;
33+
use unicode_segmentation::UnicodeSegmentation;
34+
use DataType::{LargeUtf8, Utf8, Utf8View};
3235

3336
#[derive(Debug)]
3437
pub struct RPadFunc {
@@ -84,170 +87,182 @@ impl ScalarUDFImpl for RPadFunc {
8487
}
8588

8689
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
87-
match args.len() {
88-
2 => match args[0].data_type() {
89-
DataType::Utf8 | DataType::Utf8View => {
90-
make_scalar_function(rpad::<i32, i32>, vec![])(args)
91-
}
92-
DataType::LargeUtf8 => {
93-
make_scalar_function(rpad::<i64, i64>, vec![])(args)
94-
}
95-
other => exec_err!("Unsupported data type {other:?} for function rpad"),
96-
},
97-
3 => match (args[0].data_type(), args[2].data_type()) {
98-
(
99-
DataType::Utf8 | DataType::Utf8View,
100-
DataType::Utf8 | DataType::Utf8View,
101-
) => make_scalar_function(rpad::<i32, i32>, vec![])(args),
102-
(DataType::LargeUtf8, DataType::LargeUtf8) => {
103-
make_scalar_function(rpad::<i64, i64>, vec![])(args)
104-
}
105-
(DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => {
106-
make_scalar_function(rpad::<i64, i32>, vec![])(args)
107-
}
108-
(DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => {
109-
make_scalar_function(rpad::<i32, i64>, vec![])(args)
110-
}
111-
(first_type, last_type) => {
112-
exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type)
113-
}
114-
},
115-
number => {
116-
exec_err!("unsupported arguments number {} for rpad", number)
90+
match (
91+
args.len(),
92+
args[0].data_type(),
93+
args.get(2).map(|arg| arg.data_type()),
94+
) {
95+
(2, Utf8 | Utf8View, _) => {
96+
make_scalar_function(rpad::<i32, i32>, vec![])(args)
97+
}
98+
(2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
99+
(3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
100+
make_scalar_function(rpad::<i32, i32>, vec![])(args)
101+
}
102+
(3, LargeUtf8, Some(LargeUtf8)) => {
103+
make_scalar_function(rpad::<i64, i64>, vec![])(args)
104+
}
105+
(3, Utf8 | Utf8View, Some(LargeUtf8)) => {
106+
make_scalar_function(rpad::<i32, i64>, vec![])(args)
107+
}
108+
(3, LargeUtf8, Some(Utf8 | Utf8View)) => {
109+
make_scalar_function(rpad::<i64, i32>, vec![])(args)
110+
}
111+
(_, _, _) => {
112+
exec_err!("Unsupported combination of data types for function rpad")
117113
}
118114
}
119115
}
120116
}
121117

122-
macro_rules! process_rpad {
123-
// For the two-argument case
124-
($string_array:expr, $length_array:expr) => {{
125-
$string_array
126-
.iter()
127-
.zip($length_array.iter())
128-
.map(|(string, length)| match (string, length) {
129-
(Some(string), Some(length)) => {
130-
if length > i32::MAX as i64 {
131-
return exec_err!("rpad requested length {} too large", length);
132-
}
133-
134-
let length = if length < 0 { 0 } else { length as usize };
135-
if length == 0 {
136-
Ok(Some("".to_string()))
137-
} else {
138-
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
139-
if length < graphemes.len() {
140-
Ok(Some(graphemes[..length].concat()))
141-
} else {
142-
let mut s = string.to_string();
143-
s.push_str(" ".repeat(length - graphemes.len()).as_str());
144-
Ok(Some(s))
145-
}
146-
}
147-
}
148-
_ => Ok(None),
149-
})
150-
.collect::<Result<GenericStringArray<StringArrayLen>>>()
151-
}};
152-
153-
// For the three-argument case
154-
($string_array:expr, $length_array:expr, $fill_array:expr) => {{
155-
$string_array
156-
.iter()
157-
.zip($length_array.iter())
158-
.zip($fill_array.iter())
159-
.map(|((string, length), fill)| match (string, length, fill) {
160-
(Some(string), Some(length), Some(fill)) => {
161-
if length > i32::MAX as i64 {
162-
return exec_err!("rpad requested length {} too large", length);
163-
}
164-
165-
let length = if length < 0 { 0 } else { length as usize };
166-
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
167-
let fill_chars = fill.chars().collect::<Vec<char>>();
118+
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
119+
args: &[ArrayRef],
120+
) -> Result<ArrayRef> {
121+
if args.len() < 2 || args.len() > 3 {
122+
return exec_err!(
123+
"rpad was called with {} arguments. It requires 2 or 3 arguments.",
124+
args.len()
125+
);
126+
}
168127

169-
if length < graphemes.len() {
170-
Ok(Some(graphemes[..length].concat()))
171-
} else if fill_chars.is_empty() {
172-
Ok(Some(string.to_string()))
173-
} else {
174-
let mut s = string.to_string();
175-
let char_vector: Vec<char> = (0..length - graphemes.len())
176-
.map(|l| fill_chars[l % fill_chars.len()])
177-
.collect();
178-
s.push_str(&char_vector.iter().collect::<String>());
179-
Ok(Some(s))
180-
}
181-
}
182-
_ => Ok(None),
183-
})
184-
.collect::<Result<GenericStringArray<StringArrayLen>>>()
185-
}};
128+
let length_array = as_int64_array(&args[1])?;
129+
match (
130+
args.len(),
131+
args[0].data_type(),
132+
args.get(2).map(|arg| arg.data_type()),
133+
) {
134+
(2, Utf8View, _) => {
135+
rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
136+
args[0].as_string_view(),
137+
length_array,
138+
None,
139+
)
140+
}
141+
(3, Utf8View, Some(Utf8View)) => {
142+
rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
143+
args[0].as_string_view(),
144+
length_array,
145+
Some(args[2].as_string_view()),
146+
)
147+
}
148+
(3, Utf8View, Some(Utf8 | LargeUtf8)) => {
149+
rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
150+
args[0].as_string_view(),
151+
length_array,
152+
Some(args[2].as_string::<FillArrayLen>()),
153+
)
154+
}
155+
(3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
156+
&GenericStringArray<StringArrayLen>,
157+
&StringViewArray,
158+
StringArrayLen,
159+
>(
160+
args[0].as_string::<StringArrayLen>(),
161+
length_array,
162+
Some(args[2].as_string_view()),
163+
),
164+
(_, _, _) => rpad_impl::<
165+
&GenericStringArray<StringArrayLen>,
166+
&GenericStringArray<FillArrayLen>,
167+
StringArrayLen,
168+
>(
169+
args[0].as_string::<StringArrayLen>(),
170+
length_array,
171+
args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
172+
),
173+
}
186174
}
187175

188176
/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
189177
/// rpad('hi', 5, 'xy') = 'hixyx'
190-
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
191-
args: &[ArrayRef],
192-
) -> Result<ArrayRef> {
193-
match (args.len(), args[0].data_type()) {
194-
(2, DataType::Utf8View) => {
195-
let string_array = as_string_view_array(&args[0])?;
196-
let length_array = as_int64_array(&args[1])?;
178+
pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
179+
string_array: StringArrType,
180+
length_array: &Int64Array,
181+
fill_array: Option<FillArrType>,
182+
) -> Result<ArrayRef>
183+
where
184+
StringArrType: StringArrayType<'a>,
185+
FillArrType: StringArrayType<'a>,
186+
StringArrayLen: OffsetSizeTrait,
187+
{
188+
let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
197189

198-
let result = process_rpad!(string_array, length_array)?;
199-
Ok(Arc::new(result) as ArrayRef)
190+
match fill_array {
191+
None => {
192+
string_array.iter().zip(length_array.iter()).try_for_each(
193+
|(string, length)| -> Result<(), DataFusionError> {
194+
match (string, length) {
195+
(Some(string), Some(length)) => {
196+
if length > i32::MAX as i64 {
197+
return exec_err!(
198+
"rpad requested length {} too large",
199+
length
200+
);
201+
}
202+
let length = if length < 0 { 0 } else { length as usize };
203+
if length == 0 {
204+
builder.append_value("");
205+
} else {
206+
let graphemes =
207+
string.graphemes(true).collect::<Vec<&str>>();
208+
if length < graphemes.len() {
209+
builder.append_value(graphemes[..length].concat());
210+
} else {
211+
builder.write_str(string)?;
212+
builder.write_str(
213+
&" ".repeat(length - graphemes.len()),
214+
)?;
215+
builder.append_value("");
216+
}
217+
}
218+
}
219+
_ => builder.append_null(),
220+
}
221+
Ok(())
222+
},
223+
)?;
200224
}
201-
(2, _) => {
202-
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
203-
let length_array = as_int64_array(&args[1])?;
225+
Some(fill_array) => {
226+
string_array
227+
.iter()
228+
.zip(length_array.iter())
229+
.zip(fill_array.iter())
230+
.try_for_each(
231+
|((string, length), fill)| -> Result<(), DataFusionError> {
232+
match (string, length, fill) {
233+
(Some(string), Some(length), Some(fill)) => {
234+
if length > i32::MAX as i64 {
235+
return exec_err!(
236+
"rpad requested length {} too large",
237+
length
238+
);
239+
}
240+
let length = if length < 0 { 0 } else { length as usize };
241+
let graphemes =
242+
string.graphemes(true).collect::<Vec<&str>>();
204243

205-
let result = process_rpad!(string_array, length_array)?;
206-
Ok(Arc::new(result) as ArrayRef)
207-
}
208-
(3, DataType::Utf8View) => {
209-
let string_array = as_string_view_array(&args[0])?;
210-
let length_array = as_int64_array(&args[1])?;
211-
match args[2].data_type() {
212-
DataType::Utf8View => {
213-
let fill_array = as_string_view_array(&args[2])?;
214-
let result = process_rpad!(string_array, length_array, fill_array)?;
215-
Ok(Arc::new(result) as ArrayRef)
216-
}
217-
DataType::Utf8 | DataType::LargeUtf8 => {
218-
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
219-
let result = process_rpad!(string_array, length_array, fill_array)?;
220-
Ok(Arc::new(result) as ArrayRef)
221-
}
222-
other_type => {
223-
exec_err!("unsupported type for rpad's third operator: {}", other_type)
224-
}
225-
}
226-
}
227-
(3, _) => {
228-
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
229-
let length_array = as_int64_array(&args[1])?;
230-
match args[2].data_type() {
231-
DataType::Utf8View => {
232-
let fill_array = as_string_view_array(&args[2])?;
233-
let result = process_rpad!(string_array, length_array, fill_array)?;
234-
Ok(Arc::new(result) as ArrayRef)
235-
}
236-
DataType::Utf8 | DataType::LargeUtf8 => {
237-
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
238-
let result = process_rpad!(string_array, length_array, fill_array)?;
239-
Ok(Arc::new(result) as ArrayRef)
240-
}
241-
other_type => {
242-
exec_err!("unsupported type for rpad's third operator: {}", other_type)
243-
}
244-
}
244+
if length < graphemes.len() {
245+
builder.append_value(graphemes[..length].concat());
246+
} else if fill.is_empty() {
247+
builder.append_value(string);
248+
} else {
249+
builder.write_str(string)?;
250+
fill.chars()
251+
.cycle()
252+
.take(length - graphemes.len())
253+
.for_each(|ch| builder.write_char(ch).unwrap());
254+
builder.append_value("");
255+
}
256+
}
257+
_ => builder.append_null(),
258+
}
259+
Ok(())
260+
},
261+
)?;
245262
}
246-
(other, other_type) => exec_err!(
247-
"rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}",
248-
other, other_type
249-
),
250263
}
264+
265+
Ok(Arc::new(builder.finish()) as ArrayRef)
251266
}
252267

253268
#[cfg(test)]

0 commit comments

Comments
 (0)