Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf: Support Utf8View datatype single column comparisons for SortPreservingMergeStream #15348

Merged
merged 5 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion benchmarks/src/sort_tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl RunOpt {
/// Payload Columns:
/// - Thin variant: `l_partkey` column with `BIGINT` type (1 column)
/// - Wide variant: all columns except for possible key columns (12 columns)
const SORT_QUERIES: [&'static str; 10] = [
const SORT_QUERIES: [&'static str; 11] = [
// Q1: 1 sort key (type: INTEGER, cardinality: 7) + 1 payload column
r#"
SELECT l_linenumber, l_partkey
Expand Down Expand Up @@ -159,6 +159,12 @@ impl RunOpt {
FROM lineitem
ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment
"#,
// Q11: 1 sort key (type: VARCHAR, cardinality: 4.5M) + 1 payload column
r#"
SELECT l_shipmode, l_comment, l_partkey
FROM lineitem
ORDER BY l_shipmode;
"#,
];

/// If query is specified from command line, run only that query.
Expand Down
111 changes: 109 additions & 2 deletions datafusion/core/benches/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@

use std::sync::Arc;

use arrow::array::StringViewArray;
use arrow::{
array::{DictionaryArray, Float64Array, Int64Array, StringArray},
compute::SortOptions,
datatypes::{Int32Type, Schema},
record_batch::RecordBatch,
};

use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::{
execution::context::TaskContext,
Expand Down Expand Up @@ -114,11 +114,24 @@ fn criterion_benchmark(c: &mut Criterion) {
("f64", &f64_streams),
("utf8 low cardinality", &utf8_low_cardinality_streams),
("utf8 high cardinality", &utf8_high_cardinality_streams),
(
"utf8 view low cardinality",
&utf8_view_low_cardinality_streams,
),
(
"utf8 view high cardinality",
&utf8_view_high_cardinality_streams,
),
("utf8 tuple", &utf8_tuple_streams),
("utf8 view tuple", &utf8_view_tuple_streams),
("utf8 dictionary", &dictionary_streams),
("utf8 dictionary tuple", &dictionary_tuple_streams),
("mixed dictionary tuple", &mixed_dictionary_tuple_streams),
("mixed tuple", &mixed_tuple_streams),
(
"mixed tuple with utf8 view",
&mixed_tuple_with_utf8_view_streams,
),
];

for (name, f) in cases {
Expand Down Expand Up @@ -308,6 +321,30 @@ fn utf8_low_cardinality_streams(sorted: bool) -> PartitionedBatches {
})
}

/// Create streams of random low cardinality utf8_view values
fn utf8_view_low_cardinality_streams(sorted: bool) -> PartitionedBatches {
    let mut data = DataGenerator::new().utf8_low_cardinality_values();
    if sorted {
        data.sort_unstable();
    }
    split_tuples(data, |chunk| {
        // Collect the owned strings directly into a StringViewArray column
        let column: StringViewArray = chunk.into_iter().collect();
        RecordBatch::try_from_iter(vec![("utf_view_low", Arc::new(column) as _)]).unwrap()
    })
}

/// Create streams of high cardinality (~ no duplicates) utf8_view values
fn utf8_view_high_cardinality_streams(sorted: bool) -> PartitionedBatches {
    let mut data = DataGenerator::new().utf8_high_cardinality_values();
    if sorted {
        data.sort_unstable();
    }
    split_tuples(data, |chunk| {
        // Collect the owned strings directly into a StringViewArray column
        let column: StringViewArray = chunk.into_iter().collect();
        RecordBatch::try_from_iter(vec![("utf_view_high", Arc::new(column) as _)]).unwrap()
    })
}

/// Create streams of high cardinality (~ no duplicates) utf8 values
fn utf8_high_cardinality_streams(sorted: bool) -> PartitionedBatches {
let mut values = DataGenerator::new().utf8_high_cardinality_values();
Expand Down Expand Up @@ -353,6 +390,39 @@ fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches {
})
}

/// Create a batch of (utf8_view_low, utf8_view_low, utf8_view_high)
fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches {
    let mut gen = DataGenerator::new();

    // The sort key spans all three columns, so zip them into tuples first
    // and sort the tuples as a unit.
    let mut rows: Vec<_> = gen
        .utf8_low_cardinality_values()
        .into_iter()
        .zip(gen.utf8_low_cardinality_values())
        .zip(gen.utf8_high_cardinality_values())
        .collect();

    if sorted {
        rows.sort_unstable();
    }

    split_tuples(rows, |rows| {
        // Peel the nested tuples back apart into per-column vectors.
        let (low_pairs, high_vals): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
        let (low1_vals, low2_vals): (Vec<_>, Vec<_>) = low_pairs.into_iter().unzip();

        let low1: StringViewArray = low1_vals.into_iter().collect();
        let low2: StringViewArray = low2_vals.into_iter().collect();
        let high: StringViewArray = high_vals.into_iter().collect();

        RecordBatch::try_from_iter(vec![
            ("utf_view_low1", Arc::new(low1) as _),
            ("utf_view_low2", Arc::new(low2) as _),
            ("utf_view_high", Arc::new(high) as _),
        ])
        .unwrap()
    })
}

/// Create a batch of (f64, utf8_low, utf8_low, i64)
fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches {
let mut gen = DataGenerator::new();
Expand Down Expand Up @@ -391,6 +461,44 @@ fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches {
})
}

/// Create a batch of (f64, utf8_view_low, utf8_view_low, i64)
fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> PartitionedBatches {
    let mut gen = DataGenerator::new();

    // The sort key spans all four columns, so zip them into tuples first
    // and sort the tuples as a unit.
    let mut rows: Vec<_> = gen
        .i64_values()
        .into_iter()
        .zip(gen.utf8_low_cardinality_values())
        .zip(gen.utf8_low_cardinality_values())
        .zip(gen.i64_values())
        .collect();

    if sorted {
        rows.sort_unstable();
    }

    split_tuples(rows, |rows| {
        // Peel the nested tuples back apart into per-column vectors,
        // innermost zip last.
        let (rest, i64_vals): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
        let (rest, low2_vals): (Vec<_>, Vec<_>) = rest.into_iter().unzip();
        let (f64_vals, low1_vals): (Vec<_>, Vec<_>) = rest.into_iter().unzip();

        // First column was generated as i64 but is exposed as f64.
        let f64_col: Float64Array = f64_vals.into_iter().map(|v| v as f64).collect();
        let low1: StringViewArray = low1_vals.into_iter().collect();
        let low2: StringViewArray = low2_vals.into_iter().collect();
        let i64_col: Int64Array = i64_vals.into_iter().collect();

        RecordBatch::try_from_iter(vec![
            ("f64", Arc::new(f64_col) as _),
            ("utf_view_low1", Arc::new(low1) as _),
            ("utf_view_low2", Arc::new(low2) as _),
            ("i64", Arc::new(i64_col) as _),
        ])
        .unwrap()
    })
}

/// Create a batch of (utf8_dict)
fn dictionary_streams(sorted: bool) -> PartitionedBatches {
let mut gen = DataGenerator::new();
Expand All @@ -402,7 +510,6 @@ fn dictionary_streams(sorted: bool) -> PartitionedBatches {
split_tuples(values, |v| {
let dictionary: DictionaryArray<Int32Type> =
v.iter().map(Option::as_deref).collect();

RecordBatch::try_from_iter(vec![("dict", Arc::new(dictionary) as _)]).unwrap()
})
}
Expand Down
57 changes: 55 additions & 2 deletions datafusion/physical-plan/src/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
use std::cmp::Ordering;

use arrow::array::{
types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait,
PrimitiveArray,
types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray,
GenericByteViewArray, OffsetSizeTrait, PrimitiveArray, StringViewArray,
};
use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
use arrow::compute::SortOptions;
Expand Down Expand Up @@ -281,6 +281,59 @@ impl<T: ByteArrayType> CursorArray for GenericByteArray<T> {
}
}

/// Allows a `StringViewArray` to be used as input to a sort cursor.
/// The cursor operates on the array itself, so no intermediate
/// representation is needed.
impl CursorArray for StringViewArray {
    type Values = StringViewArray;
    fn values(&self) -> Self {
        // NOTE(review): clone is presumably cheap here (Arrow arrays are
        // backed by ref-counted buffers, so this should not copy the string
        // data) — confirm against the arrow-rs docs.
        self.clone()
    }
}

impl CursorValues for StringViewArray {
fn len(&self) -> usize {
self.views().len()
}

fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
// SAFETY: Both l_idx and r_idx are guaranteed to be within bounds,
// and any null-checks are handled in the outer layers.
// Fast path: Compare the lengths before full byte comparison.

let l_view = unsafe { l.views().get_unchecked(l_idx) };
let l_len = *l_view as u32;
let r_view = unsafe { r.views().get_unchecked(r_idx) };
let r_len = *r_view as u32;
if l_len != r_len {
return false;
}

unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx).is_eq() }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a 'safety:' note to say why is is ok to use unsafe here. An example

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Omega359 for review, good example, i will address it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it would be good to justify the use of unchecked (which I think is ok here)

The docs say https://docs.rs/arrow/latest/arrow/array/struct.GenericByteViewArray.html#method.compare_unchecked

SO maybe the safety argument is mostly "The left/right_idx must within range of each array"

It also seems like we need to be comparing the Null masks too 🤔 like checking if the values are null before comparing

Given that this comparison is typically the hottest part of a merge operation maybe we should try using unchecked comparisions elswhere

}

fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
// SAFETY: The caller guarantees that idx > 0 and the indices are valid.
// Already checked it in is_eq_to_prev_one function
// Fast path: Compare the lengths of the current and previous views.
let l_view = unsafe { cursor.views().get_unchecked(idx) };
let l_len = *l_view as u32;
let r_view = unsafe { cursor.views().get_unchecked(idx - 1) };
let r_len = *r_view as u32;
if l_len != r_len {
return false;
}

unsafe {
GenericByteViewArray::compare_unchecked(cursor, idx, cursor, idx - 1).is_eq()
}
}

fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
// SAFETY: Prior assertions guarantee that l_idx and r_idx are valid indices.
// Null-checks are assumed to have been handled in the wrapper (e.g., ArrayValues).
// And the bound is checked in is_finished, it is safe to call get_unchecked
unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) }
}
}

/// A collection of sorted, nullable [`CursorValues`]
///
/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-plan/src/sorts/streaming_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ impl<'a> StreamingMergeBuilder<'a> {
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::Utf8View => merge_helper!(StringViewArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
Expand Down
Loading