From d46dc75a02cabbc1323916d16c07c93427949133 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Sun, 21 Aug 2022 19:52:20 +0200 Subject: [PATCH 1/5] Improve performance of stable sort This reworks the internals of slice::sort. Mainly: - Introduce branchless swap_next_if and optimized sortX functions - Speedup core batch extension with sort16 - Many small tweaks to reduce the amount of branches/jumps This commit is incomplete and MUST NOT be merged as is. It is missing Copy detection and would break uniqueness preservation of values that are being sorted. --- library/alloc/src/slice.rs | 450 +++++++++++++++++++++++++++++++++---- 1 file changed, 405 insertions(+), 45 deletions(-) diff --git a/library/alloc/src/slice.rs b/library/alloc/src/slice.rs index a5e7bf2a1a9f3..cd17cc9bc41d7 100644 --- a/library/alloc/src/slice.rs +++ b/library/alloc/src/slice.rs @@ -16,7 +16,9 @@ use core::borrow::{Borrow, BorrowMut}; #[cfg(not(no_global_oom_handling))] use core::cmp::Ordering::{self, Less}; #[cfg(not(no_global_oom_handling))] -use core::mem::{self, SizedTypeProperties}; +use core::mem; +#[cfg(not(no_global_oom_handling))] +use core::mem::size_of; #[cfg(not(no_global_oom_handling))] use core::ptr; @@ -203,7 +205,7 @@ impl [T] { where T: Ord, { - merge_sort(self, T::lt); + merge_sort(self, |a, b| a.lt(b)); } /// Sorts the slice with a comparator function. @@ -259,7 +261,7 @@ impl [T] { where F: FnMut(&T, &T) -> Ordering, { - merge_sort(self, |a, b| compare(a, b) == Less); + stable_sort(self, |a, b| compare(a, b) == Less); } /// Sorts the slice with a key extraction function. @@ -302,7 +304,7 @@ impl [T] { F: FnMut(&T) -> K, K: Ord, { - merge_sort(self, |a, b| f(a).lt(&f(b))); + stable_sort(self, |a, b| f(a).lt(&f(b))); } /// Sorts the slice with a key extraction function. @@ -809,16 +811,333 @@ impl ToOwned for [T] { // Sorting //////////////////////////////////////////////////////////////////////////////// +#[inline] +#[cfg(not(no_global_oom_handling))] +fn stable_sort(v: &mut [T], mut is_less: F) +where + F: FnMut(&T, &T) -> bool, +{ + if mem::size_of::() == 0 { + // Sorting has no meaningful behavior on zero-sized types. Do nothing. + return; + } + + merge_sort(v, &mut is_less); +} + +// Slices of up to this length get sorted using insertion sort. +const MAX_INSERTION: usize = 20; + +// Sort a small number of elements as fast as possible, without allocations. +#[inline] +#[cfg(not(no_global_oom_handling))] +fn sort_small(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + if len < 2 { + return; + } + + if T::is_copy() { + // SAFETY: We check the corresponding min len for sortX. + unsafe { + if len == 2 { + sort2(v, is_less); + } else if len == 3 { + sort3(v, is_less); + } else if len < 8 { + sort4(&mut v[..4], is_less); + insertion_sort_remaining(v, 4, is_less); + } else if len < 16 { + sort8(&mut v[..8], is_less); + insertion_sort_remaining(v, 8, is_less); + } else { + sort16(&mut v[..16], is_less); + insertion_sort_remaining(v, 16, is_less); + } + } + } else { + for i in (0..len - 1).rev() { + // SAFETY: We checked above that len is at least 2. + unsafe { + insert_head(&mut v[i..], is_less); + } + } + } +} + +#[cfg(not(no_global_oom_handling))] +fn merge_sort(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // Sorting has no meaningful behavior on zero-sized types. + if mem::size_of::() == 0 { + return; + } + + let len = v.len(); + + // Short arrays get sorted in-place via insertion sort to avoid allocations. 
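    // For reference, for types that do not qualify for the branchless paths in
    // sort_small (defined above), that fallback behaves like a plain insertion
    // sort. A safe sketch of that baseline (an illustrative helper, not the code
    // used here), relying only on the `is_less` contract:
    //
    //     fn insertion_sort<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], is_less: &mut F) {
    //         for i in 1..v.len() {
    //             let mut j = i;
    //             while j > 0 && is_less(&v[j], &v[j - 1]) {
    //                 v.swap(j, j - 1);
    //                 j -= 1;
    //             }
    //         }
    //     }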
+ if len <= MAX_INSERTION { + sort_small(v, is_less); + return; + } + + // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it + // shallow copies of the contents of `v` without risking the dtors running on copies if + // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run, + // which will always have length at most `len / 2`. + let mut buf = Vec::with_capacity(len / 2); + + // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a + // strange decision, but consider the fact that merges more often go in the opposite direction + // (forwards). According to benchmarks, merging forwards is slightly faster than merging + // backwards. To conclude, identifying runs by traversing backwards improves performance. + let mut runs = vec![]; + let mut end = len; + while end > 0 { + // Find the next natural run, and reverse it if it's strictly descending. + let mut start = end - 1; + if start > 0 { + start -= 1; + unsafe { + if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) { + while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) { + start -= 1; + } + v[start..end].reverse(); + } else { + while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) + { + start -= 1; + } + } + } + } + + // SAFETY: end > start. + start = provide_sorted_batch(v, start, end, is_less); + + // Push this run onto the stack. + runs.push(Run { start, len: end - start }); + end = start; + + // Merge some pairs of adjacent runs to satisfy the invariants. + while let Some(r) = collapse(&runs) { + let left = runs[r + 1]; + let right = runs[r]; + unsafe { + merge( + &mut v[left.start..right.start + right.len], + left.len, + buf.as_mut_ptr(), + is_less, + ); + } + runs[r] = Run { start: left.start, len: left.len + right.len }; + runs.remove(r + 1); + } + } + + // Finally, exactly one run must remain in the stack. + debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len); + + // Examines the stack of runs and identifies the next pair of runs to merge. More specifically, + // if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the + // algorithm should continue building a new run instead, `None` is returned. + // + // TimSort is infamous for its buggy implementations, as described here: + // http://envisage-project.eu/timsort-specification-and-verification/ + // + // The gist of the story is: we must enforce the invariants on the top four runs on the stack. + // Enforcing them on just top three is not sufficient to ensure that the invariants will still + // hold for *all* runs in the stack. + // + // This function correctly checks invariants for the top four runs. Additionally, if the top + // run starts at index 0, it will always demand a merge operation until the stack is fully + // collapsed, in order to complete the sort. 
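    // A small worked example of the check below, written as run lengths only
    // (top of the stack on the right, exactly three runs on the stack, and the
    // top run not starting at index 0):
    //
    //     [31, 20, 10] -> None        (20 > 10 and 31 > 20 + 10, both invariants hold)
    //     [31, 20, 25] -> Some(n - 2) (20 <= 25, and 31 >= 25, so the top two runs get merged)
    //     [10, 20, 25] -> Some(n - 3) (20 <= 25, and 10 < 25, so the two runs below the top get merged first)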
+ #[inline] + fn collapse(runs: &[Run]) -> Option { + let n = runs.len(); + if n >= 2 + && (runs[n - 1].start == 0 + || runs[n - 2].len <= runs[n - 1].len + || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) + || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) + { + if n >= 3 && runs[n - 3].len < runs[n - 1].len { + Some(n - 3) + } else { + Some(n - 2) + } + } else { + None + } + } + + #[derive(Clone, Copy)] + struct Run { + len: usize, + start: usize, + } +} + +/// Takes a range as denoted by start and end, that is already sorted and extends it if necessary +/// with sorts optimized for smaller ranges such as insertion sort. +#[cfg(not(no_global_oom_handling))] +fn provide_sorted_batch(v: &mut [T], mut start: usize, end: usize, is_less: &mut F) -> usize +where + F: FnMut(&T, &T) -> bool, +{ + // Not doing so is a logic bug, but not a safety bug. + debug_assert!(end > start); + + const MAX_PRE_SORT16: usize = 8; + + // Testing showed that using MAX_INSERTION here yields the best performance for many types, but + // incurs more total comparisons. A balance between least comparisons and best performance, as + // influenced by for example cache locality. + const MIN_INSERTION_RUN: usize = 10; + + // Insert some more elements into the run if it's too short. Insertion sort is faster than + // merge sort on short sequences, so this significantly improves performance. + let start_found = start; + let start_end_diff = end - start; + + if T::is_copy() && start_end_diff < MAX_PRE_SORT16 && start_found >= 16 { + // SAFETY: We just checked that start_found is >= 16. + unsafe { + start = start_found.unchecked_sub(16); + sort16(&mut v[start..start_found], is_less); + } + insertion_sort_remaining(&mut v[start..end], 16, is_less); + } else if start_end_diff < MIN_INSERTION_RUN { + start = start.saturating_sub(MIN_INSERTION_RUN - start_end_diff); + + for i in (start..start_found).rev() { + // SAFETY: We ensured that the slice length is always at lest 2 long. + // We know that start_found will be at least one less than end, + // and the range is exclusive. Which gives us i always <= (end - 2). + unsafe { + insert_head(&mut v[i..end], is_less); + } + } + } + + start +} + +// When dropped, copies from `src` into `dest`. +struct InsertionHole { + src: *const T, + dest: *mut T, +} + +impl Drop for InsertionHole { + fn drop(&mut self) { + // SAFETY: caller must ensure src is valid to read and dest is valid to write. They must not + // alias. + unsafe { + ptr::copy_nonoverlapping(self.src, self.dest, 1); + } + } +} + +/// Sort v assuming v[..offset] is already sorted. +#[inline] +#[cfg(not(no_global_oom_handling))] +fn insertion_sort_remaining(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // This is a logic but not a safety bug. + debug_assert!(offset != 0 && offset <= len); + + if len < 2 || offset == 0 { + return; + } + + // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted. + for i in offset..len { + insert_tail(&mut v[..=i], is_less); + } +} + +/// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]` +/// becomes sorted. +#[inline] +#[cfg(not(no_global_oom_handling))] +fn insert_tail(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() >= 2); + + let arr_ptr = v.as_mut_ptr(); + let i = v.len() - 1; + + unsafe { + // See insert_head which talks about why this approach is beneficial. 
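        // For reference, the net effect is the same as this safe, swap-per-step
        // sketch (assuming `v[..v.len() - 1]` is already sorted); the hole-based
        // code below moves the new element only once instead of swapping it
        // repeatedly:
        //
        //     let mut i = v.len() - 1;
        //     while i > 0 && is_less(&v[i], &v[i - 1]) {
        //         v.swap(i, i - 1);
        //         i -= 1;
        //     }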
+ let i_ptr = arr_ptr.add(i); + + // It's important that we use i_ptr here. If this check is positive and we continue, + // We want to make sure that no other copy of the value was seen by is_less. + // Otherwise we would have to copy it back. + if !is_less(&*i_ptr, &*i_ptr.sub(1)) { + return; + } + + // It's important, that we use tmp for comparison from now on. As it is the value that + // will be copied back. And notionally we could have created a divergence if we copy + // back the wrong value. + let tmp = mem::ManuallyDrop::new(ptr::read(i_ptr)); + // Intermediate state of the insertion process is always tracked by `hole`, which + // serves two purposes: + // 1. Protects integrity of `v` from panics in `is_less`. + // 2. Fills the remaining hole in `v` in the end. + // + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and + // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it + // initially held exactly once. + let mut hole = InsertionHole { src: &*tmp, dest: i_ptr.sub(1) }; + ptr::copy_nonoverlapping(hole.dest, i_ptr, 1); + + // SAFETY: We know i is at least 1. + for j in (0..(i - 1)).rev() { + let j_ptr = arr_ptr.add(j); + if !is_less(&*tmp, &*j_ptr) { + break; + } + + hole.dest = j_ptr; + ptr::copy_nonoverlapping(hole.dest, j_ptr.add(1), 1); + } + // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. + } +} + /// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted. /// /// This is the integral subroutine of insertion sort. +#[inline] #[cfg(not(no_global_oom_handling))] -fn insert_head(v: &mut [T], is_less: &mut F) +unsafe fn insert_head(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - if v.len() >= 2 && is_less(&v[1], &v[0]) { - unsafe { + debug_assert!(v.len() >= 2); + + // SAFETY: caller must ensure v is at least len 2. + unsafe { + if is_less(&v[1], &v[0]) { // There are three ways to implement insertion here: // // 1. Swap adjacent elements until the first one gets to its final destination. @@ -861,20 +1180,6 @@ where // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. } } - - // When dropped, copies from `src` into `dest`. - struct InsertionHole { - src: *const T, - dest: *mut T, - } - - impl Drop for InsertionHole { - fn drop(&mut self) { - unsafe { - ptr::copy_nonoverlapping(self.src, self.dest, 1); - } - } - } } /// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and @@ -890,8 +1195,8 @@ where F: FnMut(&T, &T) -> bool, { let len = v.len(); - let v = v.as_mut_ptr(); - let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) }; + let arr_ptr = v.as_mut_ptr(); + let (v_mid, v_end) = unsafe { (arr_ptr.add(mid), arr_ptr.add(len)) }; // The merge process first copies the shorter run into `buf`. Then it traces the newly copied // run and the longer run forwards (or backwards), comparing their next unconsumed elements and @@ -915,8 +1220,8 @@ where if mid <= len - mid { // The left run is shorter. unsafe { - ptr::copy_nonoverlapping(v, buf, mid); - hole = MergeHole { start: buf, end: buf.add(mid), dest: v }; + ptr::copy_nonoverlapping(arr_ptr, buf, mid); + hole = MergeHole { start: buf, end: buf.add(mid), dest: arr_ptr }; } // Initially, these pointers point to the beginnings of their arrays. 
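// For intuition, merging two adjacent sorted runs through scratch memory is
// equivalent to the following safe sketch (an illustrative `merge_sketch`, not
// the routine used here). It allocates a fresh Vec and requires `T: Clone`,
// both of which the real `merge` avoids, but the element flow and the
// tie-breaking rule that keeps the sort stable are the same:
//
//     fn merge_sketch<T: Clone, F: FnMut(&T, &T) -> bool>(v: &mut [T], mid: usize, is_less: &mut F) {
//         // Copy the left run out of the way, then merge it with v[mid..] back into v.
//         let left: Vec<T> = v[..mid].to_vec();
//         let (mut l, mut r, mut dest) = (0, mid, 0);
//         while l < left.len() && r < v.len() {
//             // On equal elements take the left one first, which keeps the sort stable.
//             if is_less(&v[r], &left[l]) {
//                 v[dest] = v[r].clone();
//                 r += 1;
//             } else {
//                 v[dest] = left[l].clone();
//                 l += 1;
//             }
//             dest += 1;
//         }
//         // Whatever remains of the left run goes at the end; a remaining right run
//         // is already in place.
//         while l < left.len() {
//             v[dest] = left[l].clone();
//             l += 1;
//             dest += 1;
//         }
//     }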
@@ -948,7 +1253,7 @@ where let right = &mut hole.end; let mut out = v_end; - while v < *left && buf < *right { + while arr_ptr < *left && buf < *right { // Consume the greater side. // If equal, prefer the right run to maintain stability. unsafe { @@ -993,20 +1298,62 @@ where } } -/// This merge sort borrows some (but not all) ideas from TimSort, which is described in detail -/// [here](https://github.com/python/cpython/blob/main/Objects/listsort.txt). -/// -/// The algorithm identifies strictly descending and non-descending subsequences, which are called -/// natural runs. There is a stack of pending runs yet to be merged. Each newly found run is pushed -/// onto the stack, and then some pairs of adjacent runs are merged until these two invariants are -/// satisfied: -/// -/// 1. for every `i` in `1..runs.len()`: `runs[i - 1].len > runs[i].len` -/// 2. for every `i` in `2..runs.len()`: `runs[i - 2].len > runs[i - 1].len + runs[i].len` -/// -/// The invariants ensure that the total running time is *O*(*n* \* log(*n*)) worst-case. -#[cfg(not(no_global_oom_handling))] -fn merge_sort(v: &mut [T], mut is_less: F) +trait IsCopy { + fn is_copy() -> bool; +} + +impl IsCopy for T { + fn is_copy() -> bool { + // FIXME, heuristic loss and uniqueness preservation bug. + true + } +} + +// FIXME! +// impl IsCopy for T { +// fn is_copy() -> bool { +// true +// } +// } + +// --- Branchless sorting (less branches not zero) --- + +/// Swap value with next value in array pointed to by arr_ptr if should_swap is true. +#[inline] +unsafe fn swap_next_if(arr_ptr: *mut T, should_swap: bool) { + // This is a branchless version of swap if. + // The equivalent code with a branch would be: + // + // if should_swap { + // ptr::swap_nonoverlapping(arr_ptr, arr_ptr.add(1), 1) + // } + // + // Be mindful in your benchmarking that this only starts to outperform branching code if the + // benchmark doesn't execute the same branches again and again. + // } + // + + // Give ourselves some scratch space to work with. + // We do not have to worry about drops: `MaybeUninit` does nothing when dropped. + let mut tmp = mem::MaybeUninit::::uninit(); + + // Perform the conditional swap. + // SAFETY: the caller must guarantee that `arr_ptr` and `arr_ptr.add(1)` are + // valid for writes and properly aligned. `tmp` cannot be overlapping either `arr_ptr` or + // `arr_ptr.add(1) because `tmp` was just allocated on the stack as a separate allocated object. + // And `arr_ptr` and `arr_ptr.add(1)` can't overlap either. + // However `arr_ptr` and `arr_ptr.add(should_swap as usize)` can point to the same memory if + // should_swap is false. + unsafe { + ptr::copy_nonoverlapping(arr_ptr.add(!should_swap as usize), tmp.as_mut_ptr(), 1); + ptr::copy(arr_ptr.add(should_swap as usize), arr_ptr, 1); + ptr::copy_nonoverlapping(tmp.as_ptr(), arr_ptr.add(1), 1); + } +} + +/// Swap value with next value in array pointed to by arr_ptr if should_swap is true. +#[inline] +unsafe fn swap_next_if_less(arr_ptr: *mut T, is_less: &mut F) where F: FnMut(&T, &T) -> bool, { @@ -1016,7 +1363,7 @@ where const MIN_RUN: usize = 10; // Sorting has no meaningful behavior on zero-sized types. - if T::IS_ZST { + if size_of::() == 0 { return; } @@ -1124,9 +1471,22 @@ where } } - #[derive(Clone, Copy)] - struct Run { - start: usize, - len: usize, + let mut swap = mem::MaybeUninit::<[T; 16]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + // Merge the already sorted v[0..4] with v[4..8] into swap. 
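        // Together with the calls below, this builds a small merge tree over the
        // four already sorted quarters, with the result ending up back in v:
        //
        //     v[0..4]  v[4..8]    v[8..12]  v[12..16]
        //        \       /           \        /
        //       swap[0..8]          swap[8..16]
        //               \              /
        //            final merge back into v[0..16]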
+ parity_merge8(arr_ptr, swap_ptr, is_less); + // Merge the already sorted v[8..12] with v[12..16] into swap. + parity_merge8(arr_ptr.add(8), swap_ptr.add(8), is_less); + + // v is still the same as before parity_merge8 + // swap now contains a shallow copy of v and is sorted in v[0..8] and v[8..16] + // Merge the two partitions. + // parity_merge(swap_ptr, arr_ptr, 16, is_less); + + // FIXME discuss perf loss by promising original elements in case of panic. + ptr::copy_nonoverlapping(swap_ptr, arr_ptr, 16); + parity_merge(arr_ptr, swap_ptr, 16, is_less); + ptr::copy_nonoverlapping(swap_ptr, arr_ptr, 16); } } From d5f748e338c6901d4ad89cf0a17ac644e4aababd Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Tue, 23 Aug 2022 20:14:46 +0200 Subject: [PATCH 2/5] Fix formatting --- library/alloc/src/slice.rs | 458 ++++++++++++++++++++++++++++--------- 1 file changed, 351 insertions(+), 107 deletions(-) diff --git a/library/alloc/src/slice.rs b/library/alloc/src/slice.rs index cd17cc9bc41d7..7202b69131361 100644 --- a/library/alloc/src/slice.rs +++ b/library/alloc/src/slice.rs @@ -18,8 +18,6 @@ use core::cmp::Ordering::{self, Less}; #[cfg(not(no_global_oom_handling))] use core::mem; #[cfg(not(no_global_oom_handling))] -use core::mem::size_of; -#[cfg(not(no_global_oom_handling))] use core::ptr; use crate::alloc::Allocator; @@ -168,7 +166,6 @@ pub(crate) mod hack { } } } - #[cfg(not(test))] impl [T] { /// Sorts the slice. @@ -205,7 +202,7 @@ impl [T] { where T: Ord, { - merge_sort(self, |a, b| a.lt(b)); + stable_sort(self, |a, b| a.lt(b)); } /// Sorts the slice with a comparator function. @@ -969,11 +966,7 @@ where || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) { - if n >= 3 && runs[n - 3].len < runs[n - 1].len { - Some(n - 3) - } else { - Some(n - 2) - } + if n >= 3 && runs[n - 3].len < runs[n - 1].len { Some(n - 3) } else { Some(n - 2) } } else { None } @@ -1257,7 +1250,7 @@ where // Consume the greater side. // If equal, prefer the right run to maintain stability. unsafe { - let to_copy = if is_less(&*right.sub(1), &*left.sub(1)) { + let to_copy = if is_less(&*right.offset(-1), &*left.offset(-1)) { decrement_and_get(left) } else { decrement_and_get(right) @@ -1271,12 +1264,12 @@ where unsafe fn get_and_increment(ptr: &mut *mut T) -> *mut T { let old = *ptr; - *ptr = unsafe { ptr.add(1) }; + *ptr = unsafe { ptr.offset(1) }; old } unsafe fn decrement_and_get(ptr: &mut *mut T) -> *mut T { - *ptr = unsafe { ptr.sub(1) }; + *ptr = unsafe { ptr.offset(-1) }; *ptr } @@ -1357,119 +1350,370 @@ unsafe fn swap_next_if_less(arr_ptr: *mut T, is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - // Slices of up to this length get sorted using insertion sort. - const MAX_INSERTION: usize = 20; - // Very short runs are extended using insertion sort to span at least this many elements. - const MIN_RUN: usize = 10; - - // Sorting has no meaningful behavior on zero-sized types. - if size_of::() == 0 { - return; + // SAFETY: the caller must guarantee that `arr_ptr` and `arr_ptr.add(1)` are valid for writes + // and properly aligned. + // + // PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be + // in a well defined state, without duplicates. + // + // Important to only swap if it is more and not if it is equal. is_less should return false for + // equal, so we don't swap. 
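    // For example, when sorting pairs by their first field with
    // `|a, b| a.0 < b.0`, the input `[(1, 'a'), (1, 'b')]` must keep that order:
    // is_less returns false for the equal keys, `should_swap` stays false, and
    // the two equal elements are never reordered, which is what keeps this
    // building block usable for a stable sort.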
+ unsafe { + let should_swap = is_less(&*arr_ptr.add(1), &*arr_ptr); + swap_next_if(arr_ptr, should_swap); } +} - let len = v.len(); +/// Sort the first 2 elements of v. +unsafe fn sort2(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() >= 2); - // Short arrays get sorted in-place via insertion sort to avoid allocations. - if len <= MAX_INSERTION { - if len >= 2 { - for i in (0..len - 1).rev() { - insert_head(&mut v[i..], &mut is_less); - } - } - return; + // SAFETY: caller must ensure v is at least len 2. + unsafe { + swap_next_if_less(v.as_mut_ptr(), is_less); } +} - // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it - // shallow copies of the contents of `v` without risking the dtors running on copies if - // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run, - // which will always have length at most `len / 2`. - let mut buf = Vec::with_capacity(len / 2); +/// Sort the first 3 elements of v. +unsafe fn sort3(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() >= 3); - // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a - // strange decision, but consider the fact that merges more often go in the opposite direction - // (forwards). According to benchmarks, merging forwards is slightly faster than merging - // backwards. To conclude, identifying runs by traversing backwards improves performance. - let mut runs = vec![]; - let mut end = len; - while end > 0 { - // Find the next natural run, and reverse it if it's strictly descending. - let mut start = end - 1; - if start > 0 { - start -= 1; - unsafe { - if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) { - while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) { - start -= 1; - } - v[start..end].reverse(); - } else { - while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) - { - start -= 1; - } - } - } - } + // SAFETY: caller must ensure v is at least len 3. + unsafe { + let arr_ptr = v.as_mut_ptr(); + let x1 = arr_ptr; + let x2 = arr_ptr.add(1); - // Insert some more elements into the run if it's too short. Insertion sort is faster than - // merge sort on short sequences, so this significantly improves performance. - while start > 0 && end - start < MIN_RUN { - start -= 1; - insert_head(&mut v[start..end], &mut is_less); - } + swap_next_if_less(x1, is_less); + swap_next_if_less(x2, is_less); - // Push this run onto the stack. - runs.push(Run { start, len: end - start }); - end = start; + // After two swaps we are here: + // + // abc -> ab bc | abc + // acb -> ac bc | abc + // bac -> ab bc | abc + // bca -> bc ac | bac ! + // cab -> ac bc | abc + // cba -> bc ac | bac ! + + // Which means we need to swap again. + swap_next_if_less(x1, is_less); + } +} - // Merge some pairs of adjacent runs to satisfy the invariants. - while let Some(r) = collapse(&runs) { - let left = runs[r + 1]; - let right = runs[r]; - unsafe { - merge( - &mut v[left.start..right.start + right.len], - left.len, - buf.as_mut_ptr(), - &mut is_less, - ); - } - runs[r] = Run { start: left.start, len: left.len + right.len }; - runs.remove(r + 1); +/// Sort the first 4 elements of v. +unsafe fn sort4(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() >= 4); + + // SAFETY: caller must ensure v is at least len 4. 
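    // A short trace of the compare-exchanges below on the input [3, 1, 4, 2]:
    //
    //     swap_next_if_less on v[0], v[1]: [1, 3, 4, 2]
    //     swap_next_if_less on v[2], v[3]: [1, 3, 2, 4]
    //     v[2] < v[1], so the two are swapped: [1, 2, 3, 4]
    //     the three remaining compare-exchanges find nothing left to swap.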
+ unsafe { + let arr_ptr = v.as_mut_ptr(); + let x1 = arr_ptr; + let x2 = arr_ptr.add(1); + let x3 = arr_ptr.add(2); + + swap_next_if_less(x1, is_less); + swap_next_if_less(x3, is_less); + + // PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be + // in a well defined state, without duplicates. + if is_less(&*x3, &*x2) { + ptr::swap_nonoverlapping(x2, x3, 1); + + swap_next_if_less(x1, is_less); + swap_next_if_less(x3, is_less); + swap_next_if_less(x2, is_less); } } +} - // Finally, exactly one run must remain in the stack. - debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len); +#[inline] +unsafe fn merge_up( + mut src_left: *const T, + mut src_right: *const T, + mut dest_ptr: *mut T, + is_less: &mut F, +) -> (*const T, *const T, *mut T) +where + F: FnMut(&T, &T) -> bool, +{ + // This is a branchless merge utility function. + // The equivalent code with a branch would be: + // + // if is_less(&*src_right, &*src_left) { + // // x == 0 && y == 1 + // // Elements should be swapped in final order. + // + // // Copy right side into dest[0] and the left side into dest[1] + // ptr::copy_nonoverlapping(src_right, dest_ptr, 1); + // ptr::copy_nonoverlapping(src_left, dest_ptr.add(1), 1); + // + // // Move right cursor one further, because we swapped. + // src_right = src_right.add(1); + // } else { + // // x == 1 && y == 0 + // // Elements are in order and don't need to be swapped. + // + // // Copy left side into dest[0] and the right side into dest[1] + // ptr::copy_nonoverlapping(src_left, dest_ptr, 1); + // ptr::copy_nonoverlapping(src_right, dest_ptr.add(1), 1); + // + // // Move left cursor one further, because we didn't swap. + // src_left = src_left.add(1); + // } + // + // dest_ptr = dest_ptr.add(1); - // Examines the stack of runs and identifies the next pair of runs to merge. More specifically, - // if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the - // algorithm should continue building a new run instead, `None` is returned. + // SAFETY: The caller must ensure that src_left and src_right are valid to read. And that + // dest_ptr and dest_ptr.add(1) are valid for writes. Also src and dest must not alias. + unsafe { + let x = !is_less(&*src_right, &*src_left); + let y = !x; + ptr::copy_nonoverlapping(src_right, dest_ptr.add(x as usize), 1); + ptr::copy_nonoverlapping(src_left, dest_ptr.add(y as usize), 1); + src_right = src_right.add(y as usize); + src_left = src_left.add(x as usize); + dest_ptr = dest_ptr.add(1); + } + + (src_left, src_right, dest_ptr) +} + +#[inline] +unsafe fn merge_down( + mut src_left: *const T, + mut src_right: *const T, + mut dest_ptr: *mut T, + is_less: &mut F, +) -> (*const T, *const T, *mut T) +where + F: FnMut(&T, &T) -> bool, +{ + // This is a branchless merge utility function. + // The equivalent code with a branch would be: // - // TimSort is infamous for its buggy implementations, as described here: - // http://envisage-project.eu/timsort-specification-and-verification/ + // dest_ptr = dest_ptr.sub(1); // - // The gist of the story is: we must enforce the invariants on the top four runs on the stack. - // Enforcing them on just top three is not sufficient to ensure that the invariants will still - // hold for *all* runs in the stack. + // if is_less(&*src_right, &*src_left) { + // // x == 0 && y == 1 + // // Elements should be swapped in final order. // - // This function correctly checks invariants for the top four runs. 
Additionally, if the top - // run starts at index 0, it will always demand a merge operation until the stack is fully - // collapsed, in order to complete the sort. - #[inline] - fn collapse(runs: &[Run]) -> Option { - let n = runs.len(); - if n >= 2 - && (runs[n - 1].start == 0 - || runs[n - 2].len <= runs[n - 1].len - || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) - || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) - { - if n >= 3 && runs[n - 3].len < runs[n - 1].len { Some(n - 3) } else { Some(n - 2) } - } else { - None + // // Copy right side into dest[0] and the left side into dest[1] + // ptr::copy_nonoverlapping(src_right, dest_ptr, 1); + // ptr::copy_nonoverlapping(src_left, dest_ptr.add(1), 1); + // + // // Move left cursor one back, because we swapped. + // src_left = src_left.sub(1); + // } else { + // // x == 1 && y == 0 + // // Elements are in order and don't need to be swapped. + // + // // Copy left side into dest[0] and the right side into dest[1] + // ptr::copy_nonoverlapping(src_left, dest_ptr, 1); + // ptr::copy_nonoverlapping(src_right, dest_ptr.add(1), 1); + // + // // Move right cursor one back, because we didn't swap. + // src_right = src_right.sub(1); + // } + + // SAFETY: The caller must ensure that src_left and src_right are valid to read. And that + // dest_ptr and dest_ptr.sub(1) are valid for writes. Also src and dest must not alias. + unsafe { + let x = !is_less(&*src_right, &*src_left); + let y = !x; + dest_ptr = dest_ptr.sub(1); + ptr::copy_nonoverlapping(src_right, dest_ptr.add(x as usize), 1); + ptr::copy_nonoverlapping(src_left, dest_ptr.add(y as usize), 1); + src_right = src_right.sub(x as usize); + src_left = src_left.sub(y as usize); + } + + (src_left, src_right, dest_ptr) +} + +#[inline] +unsafe fn finish_up( + src_left: *const T, + src_right: *const T, + dest_ptr: *mut T, + is_less: &mut F, +) where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: The caller must ensure that src_left and src_right are valid to + // read. And that dest_ptr is valid to write. Also src and dest must not alias. + unsafe { + let copy_ptr = if is_less(&*src_right, &*src_left) { src_right } else { src_left }; + ptr::copy_nonoverlapping(copy_ptr, dest_ptr, 1); + } +} + +#[inline] +unsafe fn finish_down( + src_left: *const T, + src_right: *const T, + dest_ptr: *mut T, + is_less: &mut F, +) where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: The caller must ensure that src_left and src_right are valid to + // read. And that dest_ptr is valid to write. Also src and dest must not alias. + unsafe { + let copy_ptr = if is_less(&*src_right, &*src_left) { src_left } else { src_right }; + ptr::copy_nonoverlapping(copy_ptr, dest_ptr, 1); + } +} + +unsafe fn parity_merge8(src_ptr: *const T, dest_ptr: *mut T, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: the caller must guarantee that `arr_ptr` and `dest_ptr` are valid for writes and + // properly aligned. And they point to a contiguous owned region of memory each at least 8 + // elements long. Also `src_ptr` and `dest_ptr` must not alias. + // + // The caller must guarantee that the values of `dest_ptr[0..len]` have trivial + // destructors that are sound to be called on a shallow copy of T. + unsafe { + // Eg. 
src == [2, 3, 4, 7, 0, 1, 6, 8] + // + // in: ptr_left = src[0] ptr_right = src[4] ptr_dest = dest[0] + // mu: ptr_left = src[0] ptr_right = src[5] ptr_dest = dest[1] dest == [0, 2, u, u, u, u, u, u] + // mu: ptr_left = src[0] ptr_right = src[6] ptr_dest = dest[2] dest == [0, 1, 2, u, u, u, u, u] + // mu: ptr_left = src[1] ptr_right = src[6] ptr_dest = dest[3] dest == [0, 1, 2, 6, u, u, u, u] + // fu: dest == [0, 1, 2, 3, u, u, u, u] + // in: ptr_left = src[3] ptr_right = src[7] ptr_dest = dest[7] + // md: ptr_left = src[3] ptr_right = src[6] ptr_dest = dest[6] dest == [0, 1, 2, 6, u, u, 7, 8] + // md: ptr_left = src[2] ptr_right = src[6] ptr_dest = dest[5] dest == [0, 1, 2, 6, u, 6, 7, 8] + // md: ptr_left = src[2] ptr_right = src[5] ptr_dest = dest[4] dest == [0, 1, 2, 3, 4, 6, 7, 8] + // fd: dest == [0, 1, 2, 3, 4, 6, 7, 8] + + let mut ptr_left = src_ptr; + let mut ptr_right = src_ptr.add(4); + let mut ptr_dest = dest_ptr; + + (ptr_left, ptr_right, ptr_dest) = merge_up(ptr_left, ptr_right, ptr_dest, is_less); + (ptr_left, ptr_right, ptr_dest) = merge_up(ptr_left, ptr_right, ptr_dest, is_less); + (ptr_left, ptr_right, ptr_dest) = merge_up(ptr_left, ptr_right, ptr_dest, is_less); + + finish_up(ptr_left, ptr_right, ptr_dest, is_less); + + // --- + + ptr_left = src_ptr.add(3); + ptr_right = src_ptr.add(7); + ptr_dest = dest_ptr.add(7); + + (ptr_left, ptr_right, ptr_dest) = merge_down(ptr_left, ptr_right, ptr_dest, is_less); + (ptr_left, ptr_right, ptr_dest) = merge_down(ptr_left, ptr_right, ptr_dest, is_less); + (ptr_left, ptr_right, ptr_dest) = merge_down(ptr_left, ptr_right, ptr_dest, is_less); + + finish_down(ptr_left, ptr_right, ptr_dest, is_less); + } +} + +unsafe fn parity_merge(src_ptr: *const T, dest_ptr: *mut T, len: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: the caller must guarantee that `src_ptr` and `dest_ptr` are valid for writes and + // properly aligned. And they point to a contiguous owned region of memory each at least len + // elements long. Also `src_ptr` and `dest_ptr` must not alias. + // + // The caller must guarantee that the values of `dest_ptr[0..len]` have trivial + // destructors that are sound to be called on a shallow copy of T. + unsafe { + let mut block = len / 2; + + let mut ptr_left = src_ptr; + let mut ptr_right = src_ptr.add(block); + let mut ptr_data = dest_ptr; + + let mut t_ptr_left = src_ptr.add(block - 1); + let mut t_ptr_right = src_ptr.add(len - 1); + let mut t_ptr_data = dest_ptr.add(len - 1); + + block -= 1; + while block != 0 { + (ptr_left, ptr_right, ptr_data) = merge_up(ptr_left, ptr_right, ptr_data, is_less); + (t_ptr_left, t_ptr_right, t_ptr_data) = + merge_down(t_ptr_left, t_ptr_right, t_ptr_data, is_less); + + block -= 1; + } + + finish_up(ptr_left, ptr_right, ptr_data, is_less); + finish_down(t_ptr_left, t_ptr_right, t_ptr_data, is_less); + } +} + +// This implementation is only valid for Copy types. For these reasons: +// 1. Panic safety +// 2. Uniqueness preservation for types with interior mutability. +unsafe fn sort8(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() == 8); + + // SAFETY: caller must ensure v is at least len 8. + unsafe { + sort4(v, is_less); + sort4(&mut v[4..], is_less); + + let arr_ptr = v.as_mut_ptr(); + if !is_less(&*arr_ptr.add(4), &*arr_ptr.add(3)) { + return; } + + let mut swap = mem::MaybeUninit::<[T; 8]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + // We know the two parts v[0..4] and v[4..8] are already sorted. 
Merge into swap_ptr, so + // that a panic of is_less leaves arr_ptr unchanged and in a valid state, preserving all + // original elements. + parity_merge8(arr_ptr, swap_ptr, is_less); + // Once the merge is done, copy everything from swap to arr. + ptr::copy_nonoverlapping(swap_ptr, arr_ptr, 8); } +} + +// This implementation is only valid for Copy types. For these reasons: +// 1. Panic safety +// 2. Uniqueness preservation for types with interior mutability. +unsafe fn sort16(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() == 16); + + // SAFETY: caller must ensure v is at least len 16. + unsafe { + // Sort the 4 parts of v individually. + sort4(v, is_less); + sort4(&mut v[4..], is_less); + sort4(&mut v[8..], is_less); + sort4(&mut v[12..], is_less); + + // If all 3 pairs of border elements are sorted, we know the whole 16 elements are now sorted. + // Doing this check reduces the total comparisons done on average for different input patterns. + let arr_ptr = v.as_mut_ptr(); + if !is_less(&*arr_ptr.add(4), &*arr_ptr.add(3)) + && !is_less(&*arr_ptr.add(8), &*arr_ptr.add(7)) + && !is_less(&*arr_ptr.add(12), &*arr_ptr.add(11)) + { + return; + } let mut swap = mem::MaybeUninit::<[T; 16]>::uninit(); let swap_ptr = swap.as_mut_ptr() as *mut T; From 05a6b9c3ddabb194f0591f994eac71bf03babc12 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Wed, 26 Oct 2022 00:20:29 +0200 Subject: [PATCH 3/5] Rework new_stable_sort with sorting-network --- library/alloc/src/slice.rs | 780 +++++++++++++++---------------------- 1 file changed, 314 insertions(+), 466 deletions(-) diff --git a/library/alloc/src/slice.rs b/library/alloc/src/slice.rs index 7202b69131361..882559b33fd05 100644 --- a/library/alloc/src/slice.rs +++ b/library/alloc/src/slice.rs @@ -822,48 +822,88 @@ where merge_sort(v, &mut is_less); } -// Slices of up to this length get sorted using insertion sort. -const MAX_INSERTION: usize = 20; - // Sort a small number of elements as fast as possible, without allocations. -#[inline] #[cfg(not(no_global_oom_handling))] -fn sort_small(v: &mut [T], is_less: &mut F) +fn stable_sort_small(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { let len = v.len(); + // This implementation is really not fit for anything beyond that, and the call is probably a + // bug. + debug_assert!(len <= 40); + if len < 2 { return; } - if T::is_copy() { - // SAFETY: We check the corresponding min len for sortX. + // It's not clear that using custom code for specific sizes is worth it here. + // So we go with the simpler code. + let offset = if len <= 6 || !qualifies_for_branchless_sort::() { + 1 + } else { + // Once a certain threshold is reached, it becomes worth it to analyze the input and do + // branchless swapping for the first 5 elements. 
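        // For example, for v = [1, 2, 3, 4, 5, 0, ..] the four adjacent
        // comparisons below all return false (swap_count == 0), so the first five
        // elements are accepted as already sorted and insertion sort continues
        // from offset 5. For v = [5, 4, 3, 2, 1, ..] all four return true
        // (swap_count == 4), the descending run is extended as far as it goes,
        // reversed, and the rest is handled by insertion sort. Anything in
        // between falls through to the branchless swaps.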
+ + // SAFETY: We just checked that len >= 5 unsafe { - if len == 2 { - sort2(v, is_less); - } else if len == 3 { - sort3(v, is_less); - } else if len < 8 { - sort4(&mut v[..4], is_less); - insertion_sort_remaining(v, 4, is_less); - } else if len < 16 { - sort8(&mut v[..8], is_less); - insertion_sort_remaining(v, 8, is_less); + let arr_ptr = v.as_mut_ptr(); + + let should_swap_0_1 = is_less(&*arr_ptr.add(1), &*arr_ptr.add(0)); + let should_swap_1_2 = is_less(&*arr_ptr.add(2), &*arr_ptr.add(1)); + let should_swap_2_3 = is_less(&*arr_ptr.add(3), &*arr_ptr.add(2)); + let should_swap_3_4 = is_less(&*arr_ptr.add(4), &*arr_ptr.add(3)); + + let swap_count = should_swap_0_1 as usize + + should_swap_1_2 as usize + + should_swap_2_3 as usize + + should_swap_3_4 as usize; + + if swap_count == 0 { + // Potentially already sorted. No need to swap, we know the first 5 elements are + // already in the right order. + 5 + } else if swap_count == 4 { + // Potentially reversed. + let mut rev_i = 4; + while rev_i < (len - 1) { + if !is_less(&*arr_ptr.add(rev_i + 1), &*arr_ptr.add(rev_i)) { + break; + } + rev_i += 1; + } + rev_i += 1; + v[..rev_i].reverse(); + insertion_sort_shift_left(v, rev_i, is_less); + return; } else { - sort16(&mut v[..16], is_less); - insertion_sort_remaining(v, 16, is_less); - } - } - } else { - for i in (0..len - 1).rev() { - // SAFETY: We checked above that len is at least 2. - unsafe { - insert_head(&mut v[i..], is_less); + // Potentially random pattern. + branchless_swap(arr_ptr.add(0), arr_ptr.add(1), should_swap_0_1); + branchless_swap(arr_ptr.add(2), arr_ptr.add(3), should_swap_2_3); + + if len >= 12 { + // This aims to find a good balance between generating more code, which is bad + // for cold loops and improving hot code while not increasing mean comparison + // count too much. + sort8_stable(&mut v[4..12], is_less); + insertion_sort_shift_left(&mut v[4..], 8, is_less); + insertion_sort_shift_right(v, 4, is_less); + return; + } else { + // Complete the sort network for the first 4 elements. + swap_next_if_less(arr_ptr.add(1), is_less); + swap_next_if_less(arr_ptr.add(2), is_less); + swap_next_if_less(arr_ptr.add(0), is_less); + swap_next_if_less(arr_ptr.add(1), is_less); + + 4 + } } } - } + }; + + insertion_sort_shift_left(v, offset, is_less); } #[cfg(not(no_global_oom_handling))] @@ -878,17 +918,19 @@ where let len = v.len(); + // Slices of up to this length get sorted using insertion sort. + const MAX_NO_ALLOC_SIZE: usize = 20; + // Short arrays get sorted in-place via insertion sort to avoid allocations. - if len <= MAX_INSERTION { - sort_small(v, is_less); + if len <= MAX_NO_ALLOC_SIZE { + stable_sort_small(v, is_less); return; } - // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it - // shallow copies of the contents of `v` without risking the dtors running on copies if - // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run, - // which will always have length at most `len / 2`. - let mut buf = Vec::with_capacity(len / 2); + // Don't allocate right at the beginning, wait to see if the slice is already sorted or + // reversed. + let mut buf; + let mut buf_ptr: *mut T = ptr::null_mut(); // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a // strange decision, but consider the fact that merges more often go in the opposite direction @@ -916,6 +958,19 @@ where } } + if start == 0 && end == len { + // The input was either fully ascending or descending. 
It is now sorted and we can + // return without allocating. + return; + } else if buf_ptr.is_null() { + // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it + // shallow copies of the contents of `v` without risking the dtors running on copies if + // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the + // shorter run, which will always have length at most `len / 2`. + buf = Vec::with_capacity(len / 2); + buf_ptr = buf.as_mut_ptr(); + } + // SAFETY: end > start. start = provide_sorted_batch(v, start, end, is_less); @@ -928,12 +983,7 @@ where let left = runs[r + 1]; let right = runs[r]; unsafe { - merge( - &mut v[left.start..right.start + right.len], - left.len, - buf.as_mut_ptr(), - is_less, - ); + merge(&mut v[left.start..right.start + right.len], left.len, buf_ptr, is_less); } runs[r] = Run { start: left.start, len: left.len + right.len }; runs.remove(r + 1); @@ -966,7 +1016,11 @@ where || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) { - if n >= 3 && runs[n - 3].len < runs[n - 1].len { Some(n - 3) } else { Some(n - 2) } + if n >= 3 && runs[n - 3].len < runs[n - 1].len { + Some(n - 3) + } else { + Some(n - 2) + } } else { None } @@ -986,11 +1040,8 @@ fn provide_sorted_batch(v: &mut [T], mut start: usize, end: usize, is_less where F: FnMut(&T, &T) -> bool, { - // Not doing so is a logic bug, but not a safety bug. debug_assert!(end > start); - const MAX_PRE_SORT16: usize = 8; - // Testing showed that using MAX_INSERTION here yields the best performance for many types, but // incurs more total comparisons. A balance between least comparisons and best performance, as // influenced by for example cache locality. @@ -1001,24 +1052,48 @@ where let start_found = start; let start_end_diff = end - start; - if T::is_copy() && start_end_diff < MAX_PRE_SORT16 && start_found >= 16 { - // SAFETY: We just checked that start_found is >= 16. + const FAST_SORT_SIZE: usize = 24; + + if qualifies_for_branchless_sort::() && end >= (FAST_SORT_SIZE + 3) && start_end_diff <= 6 { + // For random inputs on average how many elements are naturally already sorted + // (start_end_diff) will be relatively small. And it's faster to avoid a merge operation + // between the newly sorted elements on the left by the sort network and the already sorted + // elements. Instead if there are 3 or fewer already sorted elements they get merged by + // participating in the sort network. This wastes the information that they are already + // sorted, but extra branching is not worth it. + // + // Note, this optimization significantly reduces comparison count, versus just always using + // insertion_sort_shift_left. Insertion sort is faster than calling merge here, and this is + // yet faster starting at FAST_SORT_SIZE 20. + let is_small_pre_sorted = start_end_diff <= 3; + + start = if is_small_pre_sorted { + end - FAST_SORT_SIZE + } else { + start_found - (FAST_SORT_SIZE - 3) + }; + + // SAFETY: start >= 0 && start + FAST_SORT_SIZE <= end unsafe { - start = start_found.unchecked_sub(16); - sort16(&mut v[start..start_found], is_less); + // Use a straight-line sorting network here instead of some hybrid network with early + // exit. If the input is already sorted the previous adaptive analysis path of TimSort + // ought to have found it. So we prefer minimizing the total amount of comparisons, + // which are user provided and may be of arbitrary cost. 
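            // "Straight-line" here means the sequence of compare-exchanges is fixed
            // and does not depend on the data. As a minimal illustration, a stable
            // 3-element network is just three adjacent compare-exchanges:
            //
            //     swap_if_less(arr_ptr, 0, 1, is_less);
            //     swap_if_less(arr_ptr, 1, 2, is_less);
            //     swap_if_less(arr_ptr, 0, 1, is_less);
            //
            // The call below applies the same idea at FAST_SORT_SIZE elements.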
+ sort24_stable(&mut v[start..(start + FAST_SORT_SIZE)], is_less); } - insertion_sort_remaining(&mut v[start..end], 16, is_less); - } else if start_end_diff < MIN_INSERTION_RUN { - start = start.saturating_sub(MIN_INSERTION_RUN - start_end_diff); - - for i in (start..start_found).rev() { - // SAFETY: We ensured that the slice length is always at lest 2 long. - // We know that start_found will be at least one less than end, - // and the range is exclusive. Which gives us i always <= (end - 2). - unsafe { - insert_head(&mut v[i..end], is_less); - } + + // For most patterns this branch should have good prediction accuracy. + if !is_small_pre_sorted { + insertion_sort_shift_left(&mut v[start..end], FAST_SORT_SIZE, is_less); } + } else if start_end_diff < MIN_INSERTION_RUN && start != 0 { + // v[start_found..end] are elements that are already sorted in the input. We want to extend + // the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is + // more efficient that trying to push those already sorted elements to the left. + + start = if end >= MIN_INSERTION_RUN { end - MIN_INSERTION_RUN } else { 0 }; + + insertion_sort_shift_right(&mut v[start..end], start_found - start, is_less); } start @@ -1032,41 +1107,15 @@ struct InsertionHole { impl Drop for InsertionHole { fn drop(&mut self) { - // SAFETY: caller must ensure src is valid to read and dest is valid to write. They must not - // alias. unsafe { ptr::copy_nonoverlapping(self.src, self.dest, 1); } } } -/// Sort v assuming v[..offset] is already sorted. -#[inline] -#[cfg(not(no_global_oom_handling))] -fn insertion_sort_remaining(v: &mut [T], offset: usize, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - let len = v.len(); - - // This is a logic but not a safety bug. - debug_assert!(offset != 0 && offset <= len); - - if len < 2 || offset == 0 { - return; - } - - // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted. - for i in offset..len { - insert_tail(&mut v[..=i], is_less); - } -} - /// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]` /// becomes sorted. -#[inline] -#[cfg(not(no_global_oom_handling))] -fn insert_tail(v: &mut [T], is_less: &mut F) +unsafe fn insert_tail(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { @@ -1075,6 +1124,7 @@ where let arr_ptr = v.as_mut_ptr(); let i = v.len() - 1; + // SAFETY: caller must ensure v is at least len 2. unsafe { // See insert_head which talks about why this approach is beneficial. let i_ptr = arr_ptr.add(i); @@ -1110,27 +1160,82 @@ where break; } + ptr::copy_nonoverlapping(j_ptr, hole.dest, 1); hole.dest = j_ptr; - ptr::copy_nonoverlapping(hole.dest, j_ptr.add(1), 1); } // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. } } +/// Sort v assuming v[..offset] is already sorted. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. Even improving performance in some cases. +#[inline(never)] +fn insertion_sort_shift_left(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // This is a logic but not a safety bug. + debug_assert!(offset != 0 && offset <= len); + + if ((len < 2) as u8 + (offset == 0) as u8) != 0 { + return; + } + + // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted. + for i in offset..len { + // SAFETY: we tested that len >= 2. 
+ unsafe { + // Maybe use insert_head here and avoid additional code. + insert_tail(&mut v[..=i], is_less); + } + } +} + +/// Sort v assuming v[offset..] is already sorted. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. Even improving performance in some cases. +#[inline(never)] +fn insertion_sort_shift_right(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // This is a logic but not a safety bug. + debug_assert!(offset != 0 && offset <= len); + + if ((len < 2) as u8 + (offset == 0) as u8) != 0 { + return; + } + + // Shift each element of the unsorted region v[..i] as far left as is needed to make v sorted. + for i in (0..offset).rev() { + // We ensured that the slice length is always at least 2 long. + // We know that start_found will be at least one less than end, + // and the range is exclusive. Which gives us i always <= (end - 2). + unsafe { + insert_head(&mut v[i..len], is_less); + } + } +} + /// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted. /// /// This is the integral subroutine of insertion sort. -#[inline] -#[cfg(not(no_global_oom_handling))] unsafe fn insert_head(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { debug_assert!(v.len() >= 2); - // SAFETY: caller must ensure v is at least len 2. - unsafe { - if is_less(&v[1], &v[0]) { + if is_less(&v[1], &v[0]) { + // SAFETY: caller must ensure v is at least len 2. + unsafe { // There are three ways to implement insertion here: // // 1. Swap adjacent elements until the first one gets to its final destination. @@ -1291,446 +1396,189 @@ where } } -trait IsCopy { +#[rustc_unsafe_specialization_marker] +trait IsCopyMarker {} + +impl IsCopyMarker for T {} + +trait IsCopy { fn is_copy() -> bool; } -impl IsCopy for T { +impl IsCopy for T { + default fn is_copy() -> bool { + false + } +} + +impl IsCopy for T { fn is_copy() -> bool { - // FIXME, heuristic loss and uniqueness preservation bug. true } } -// FIXME! -// impl IsCopy for T { -// fn is_copy() -> bool { -// true -// } -// } +#[inline] +fn qualifies_for_branchless_sort() -> bool { + // This is a heuristic, and as such it will guess wrong from time to time. The two parts broken + // down: + // + // - Copy: We guess that copy types have relatively cheap comparison functions. The branchless + // sort does on average 8% more comparisons for random inputs and up to 50% in some + // circumstances. The time won avoiding branches can be offset by this increase in + // comparisons if the type is expensive to compare. + // + // - Type size: Large types are more expensive to move and the time won avoiding branches can be + // offset by the increased cost of moving the values. + T::is_copy() && (mem::size_of::() <= mem::size_of::<[usize; 4]>()) +} // --- Branchless sorting (less branches not zero) --- -/// Swap value with next value in array pointed to by arr_ptr if should_swap is true. +/// Swap two values in array pointed to by a_ptr and b_ptr if b is less than a. #[inline] -unsafe fn swap_next_if(arr_ptr: *mut T, should_swap: bool) { +unsafe fn branchless_swap(a_ptr: *mut T, b_ptr: *mut T, should_swap: bool) { // This is a branchless version of swap if. 
// The equivalent code with a branch would be: // // if should_swap { - // ptr::swap_nonoverlapping(arr_ptr, arr_ptr.add(1), 1) - // } - // - // Be mindful in your benchmarking that this only starts to outperform branching code if the - // benchmark doesn't execute the same branches again and again. + // ptr::swap_nonoverlapping(a_ptr, b_ptr, 1); // } - // // Give ourselves some scratch space to work with. // We do not have to worry about drops: `MaybeUninit` does nothing when dropped. let mut tmp = mem::MaybeUninit::::uninit(); - // Perform the conditional swap. - // SAFETY: the caller must guarantee that `arr_ptr` and `arr_ptr.add(1)` are - // valid for writes and properly aligned. `tmp` cannot be overlapping either `arr_ptr` or - // `arr_ptr.add(1) because `tmp` was just allocated on the stack as a separate allocated object. - // And `arr_ptr` and `arr_ptr.add(1)` can't overlap either. - // However `arr_ptr` and `arr_ptr.add(should_swap as usize)` can point to the same memory if - // should_swap is false. - unsafe { - ptr::copy_nonoverlapping(arr_ptr.add(!should_swap as usize), tmp.as_mut_ptr(), 1); - ptr::copy(arr_ptr.add(should_swap as usize), arr_ptr, 1); - ptr::copy_nonoverlapping(tmp.as_ptr(), arr_ptr.add(1), 1); - } -} + // The goal is to generate cmov instructions here. + let a_swap_ptr = if should_swap { b_ptr } else { a_ptr }; + let b_swap_ptr = if should_swap { a_ptr } else { b_ptr }; -/// Swap value with next value in array pointed to by arr_ptr if should_swap is true. -#[inline] -unsafe fn swap_next_if_less(arr_ptr: *mut T, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - // SAFETY: the caller must guarantee that `arr_ptr` and `arr_ptr.add(1)` are valid for writes - // and properly aligned. - // - // PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be - // in a well defined state, without duplicates. - // - // Important to only swap if it is more and not if it is equal. is_less should return false for - // equal, so we don't swap. + // SAFETY: the caller must guarantee that `a_ptr` and `b_ptr` are valid for writes + // and properly aligned, and part of the same allocation, and do not alias. unsafe { - let should_swap = is_less(&*arr_ptr.add(1), &*arr_ptr); - swap_next_if(arr_ptr, should_swap); - } -} - -/// Sort the first 2 elements of v. -unsafe fn sort2(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - debug_assert!(v.len() >= 2); - - // SAFETY: caller must ensure v is at least len 2. - unsafe { - swap_next_if_less(v.as_mut_ptr(), is_less); - } -} - -/// Sort the first 3 elements of v. -unsafe fn sort3(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - debug_assert!(v.len() >= 3); - - // SAFETY: caller must ensure v is at least len 3. - unsafe { - let arr_ptr = v.as_mut_ptr(); - let x1 = arr_ptr; - let x2 = arr_ptr.add(1); - - swap_next_if_less(x1, is_less); - swap_next_if_less(x2, is_less); - - // After two swaps we are here: - // - // abc -> ab bc | abc - // acb -> ac bc | abc - // bac -> ab bc | abc - // bca -> bc ac | bac ! - // cab -> ac bc | abc - // cba -> bc ac | bac ! - - // Which means we need to swap again. - swap_next_if_less(x1, is_less); + ptr::copy_nonoverlapping(b_swap_ptr, tmp.as_mut_ptr(), 1); + ptr::copy(a_swap_ptr, a_ptr, 1); + ptr::copy_nonoverlapping(tmp.as_ptr(), b_ptr, 1); } } -/// Sort the first 4 elements of v. -unsafe fn sort4(v: &mut [T], is_less: &mut F) +/// Swap two values in array pointed to by a_ptr and b_ptr if b is less than a. 
+#[inline] +unsafe fn swap_if_less(arr_ptr: *mut T, a: usize, b: usize, is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - debug_assert!(v.len() >= 4); - - // SAFETY: caller must ensure v is at least len 4. + // SAFETY: the caller must guarantee that `a` and `b` each added to `arr_ptr` yield valid + // pointers into `arr_ptr`. and properly aligned, and part of the same allocation, and do not + // alias. `a` and `b` must be different numbers. unsafe { - let arr_ptr = v.as_mut_ptr(); - let x1 = arr_ptr; - let x2 = arr_ptr.add(1); - let x3 = arr_ptr.add(2); + debug_assert!(a != b); - swap_next_if_less(x1, is_less); - swap_next_if_less(x3, is_less); + let a_ptr = arr_ptr.add(a); + let b_ptr = arr_ptr.add(b); // PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be // in a well defined state, without duplicates. - if is_less(&*x3, &*x2) { - ptr::swap_nonoverlapping(x2, x3, 1); - - swap_next_if_less(x1, is_less); - swap_next_if_less(x3, is_less); - swap_next_if_less(x2, is_less); - } - } -} -#[inline] -unsafe fn merge_up( - mut src_left: *const T, - mut src_right: *const T, - mut dest_ptr: *mut T, - is_less: &mut F, -) -> (*const T, *const T, *mut T) -where - F: FnMut(&T, &T) -> bool, -{ - // This is a branchless merge utility function. - // The equivalent code with a branch would be: - // - // if is_less(&*src_right, &*src_left) { - // // x == 0 && y == 1 - // // Elements should be swapped in final order. - // - // // Copy right side into dest[0] and the left side into dest[1] - // ptr::copy_nonoverlapping(src_right, dest_ptr, 1); - // ptr::copy_nonoverlapping(src_left, dest_ptr.add(1), 1); - // - // // Move right cursor one further, because we swapped. - // src_right = src_right.add(1); - // } else { - // // x == 1 && y == 0 - // // Elements are in order and don't need to be swapped. - // - // // Copy left side into dest[0] and the right side into dest[1] - // ptr::copy_nonoverlapping(src_left, dest_ptr, 1); - // ptr::copy_nonoverlapping(src_right, dest_ptr.add(1), 1); - // - // // Move left cursor one further, because we didn't swap. - // src_left = src_left.add(1); - // } - // - // dest_ptr = dest_ptr.add(1); + // Important to only swap if it is more and not if it is equal. is_less should return false for + // equal, so we don't swap. + let should_swap = is_less(&*b_ptr, &*a_ptr); - // SAFETY: The caller must ensure that src_left and src_right are valid to read. And that - // dest_ptr and dest_ptr.add(1) are valid for writes. Also src and dest must not alias. - unsafe { - let x = !is_less(&*src_right, &*src_left); - let y = !x; - ptr::copy_nonoverlapping(src_right, dest_ptr.add(x as usize), 1); - ptr::copy_nonoverlapping(src_left, dest_ptr.add(y as usize), 1); - src_right = src_right.add(y as usize); - src_left = src_left.add(x as usize); - dest_ptr = dest_ptr.add(1); + branchless_swap(a_ptr, b_ptr, should_swap); } - - (src_left, src_right, dest_ptr) } +/// Comparing and swapping anything but adjacent elements will yield a non stable sort. +/// So this must be fundamental building block for stable sorting networks. #[inline] -unsafe fn merge_down( - mut src_left: *const T, - mut src_right: *const T, - mut dest_ptr: *mut T, - is_less: &mut F, -) -> (*const T, *const T, *mut T) -where - F: FnMut(&T, &T) -> bool, -{ - // This is a branchless merge utility function. 
- // The equivalent code with a branch would be: - // - // dest_ptr = dest_ptr.sub(1); - // - // if is_less(&*src_right, &*src_left) { - // // x == 0 && y == 1 - // // Elements should be swapped in final order. - // - // // Copy right side into dest[0] and the left side into dest[1] - // ptr::copy_nonoverlapping(src_right, dest_ptr, 1); - // ptr::copy_nonoverlapping(src_left, dest_ptr.add(1), 1); - // - // // Move left cursor one back, because we swapped. - // src_left = src_left.sub(1); - // } else { - // // x == 1 && y == 0 - // // Elements are in order and don't need to be swapped. - // - // // Copy left side into dest[0] and the right side into dest[1] - // ptr::copy_nonoverlapping(src_left, dest_ptr, 1); - // ptr::copy_nonoverlapping(src_right, dest_ptr.add(1), 1); - // - // // Move right cursor one back, because we didn't swap. - // src_right = src_right.sub(1); - // } - - // SAFETY: The caller must ensure that src_left and src_right are valid to read. And that - // dest_ptr and dest_ptr.sub(1) are valid for writes. Also src and dest must not alias. - unsafe { - let x = !is_less(&*src_right, &*src_left); - let y = !x; - dest_ptr = dest_ptr.sub(1); - ptr::copy_nonoverlapping(src_right, dest_ptr.add(x as usize), 1); - ptr::copy_nonoverlapping(src_left, dest_ptr.add(y as usize), 1); - src_right = src_right.sub(x as usize); - src_left = src_left.sub(y as usize); - } - - (src_left, src_right, dest_ptr) -} - -#[inline] -unsafe fn finish_up( - src_left: *const T, - src_right: *const T, - dest_ptr: *mut T, - is_less: &mut F, -) where - F: FnMut(&T, &T) -> bool, -{ - // SAFETY: The caller must ensure that src_left and src_right are valid to - // read. And that dest_ptr is valid to write. Also src and dest must not alias. - unsafe { - let copy_ptr = if is_less(&*src_right, &*src_left) { src_right } else { src_left }; - ptr::copy_nonoverlapping(copy_ptr, dest_ptr, 1); - } -} - -#[inline] -unsafe fn finish_down( - src_left: *const T, - src_right: *const T, - dest_ptr: *mut T, - is_less: &mut F, -) where - F: FnMut(&T, &T) -> bool, -{ - // SAFETY: The caller must ensure that src_left and src_right are valid to - // read. And that dest_ptr is valid to write. Also src and dest must not alias. - unsafe { - let copy_ptr = if is_less(&*src_right, &*src_left) { src_left } else { src_right }; - ptr::copy_nonoverlapping(copy_ptr, dest_ptr, 1); - } -} - -unsafe fn parity_merge8(src_ptr: *const T, dest_ptr: *mut T, is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - // SAFETY: the caller must guarantee that `arr_ptr` and `dest_ptr` are valid for writes and - // properly aligned. And they point to a contiguous owned region of memory each at least 8 - // elements long. Also `src_ptr` and `dest_ptr` must not alias. - // - // The caller must guarantee that the values of `dest_ptr[0..len]` have trivial - // destructors that are sound to be called on a shallow copy of T. - unsafe { - // Eg. 
src == [2, 3, 4, 7, 0, 1, 6, 8] - // - // in: ptr_left = src[0] ptr_right = src[4] ptr_dest = dest[0] - // mu: ptr_left = src[0] ptr_right = src[5] ptr_dest = dest[1] dest == [0, 2, u, u, u, u, u, u] - // mu: ptr_left = src[0] ptr_right = src[6] ptr_dest = dest[2] dest == [0, 1, 2, u, u, u, u, u] - // mu: ptr_left = src[1] ptr_right = src[6] ptr_dest = dest[3] dest == [0, 1, 2, 6, u, u, u, u] - // fu: dest == [0, 1, 2, 3, u, u, u, u] - // in: ptr_left = src[3] ptr_right = src[7] ptr_dest = dest[7] - // md: ptr_left = src[3] ptr_right = src[6] ptr_dest = dest[6] dest == [0, 1, 2, 6, u, u, 7, 8] - // md: ptr_left = src[2] ptr_right = src[6] ptr_dest = dest[5] dest == [0, 1, 2, 6, u, 6, 7, 8] - // md: ptr_left = src[2] ptr_right = src[5] ptr_dest = dest[4] dest == [0, 1, 2, 3, 4, 6, 7, 8] - // fd: dest == [0, 1, 2, 3, 4, 6, 7, 8] - - let mut ptr_left = src_ptr; - let mut ptr_right = src_ptr.add(4); - let mut ptr_dest = dest_ptr; - - (ptr_left, ptr_right, ptr_dest) = merge_up(ptr_left, ptr_right, ptr_dest, is_less); - (ptr_left, ptr_right, ptr_dest) = merge_up(ptr_left, ptr_right, ptr_dest, is_less); - (ptr_left, ptr_right, ptr_dest) = merge_up(ptr_left, ptr_right, ptr_dest, is_less); - - finish_up(ptr_left, ptr_right, ptr_dest, is_less); - - // --- - - ptr_left = src_ptr.add(3); - ptr_right = src_ptr.add(7); - ptr_dest = dest_ptr.add(7); - - (ptr_left, ptr_right, ptr_dest) = merge_down(ptr_left, ptr_right, ptr_dest, is_less); - (ptr_left, ptr_right, ptr_dest) = merge_down(ptr_left, ptr_right, ptr_dest, is_less); - (ptr_left, ptr_right, ptr_dest) = merge_down(ptr_left, ptr_right, ptr_dest, is_less); - - finish_down(ptr_left, ptr_right, ptr_dest, is_less); - } -} - -unsafe fn parity_merge(src_ptr: *const T, dest_ptr: *mut T, len: usize, is_less: &mut F) +unsafe fn swap_next_if_less(arr_ptr: *mut T, is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - // SAFETY: the caller must guarantee that `src_ptr` and `dest_ptr` are valid for writes and - // properly aligned. And they point to a contiguous owned region of memory each at least len - // elements long. Also `src_ptr` and `dest_ptr` must not alias. - // - // The caller must guarantee that the values of `dest_ptr[0..len]` have trivial - // destructors that are sound to be called on a shallow copy of T. + // SAFETY: the caller must guarantee that `arr_ptr` and `arr_ptr.add(1)` yield valid + // pointers that are properly aligned, and part of the same allocation. unsafe { - let mut block = len / 2; - - let mut ptr_left = src_ptr; - let mut ptr_right = src_ptr.add(block); - let mut ptr_data = dest_ptr; - - let mut t_ptr_left = src_ptr.add(block - 1); - let mut t_ptr_right = src_ptr.add(len - 1); - let mut t_ptr_data = dest_ptr.add(len - 1); - - block -= 1; - while block != 0 { - (ptr_left, ptr_right, ptr_data) = merge_up(ptr_left, ptr_right, ptr_data, is_less); - (t_ptr_left, t_ptr_right, t_ptr_data) = - merge_down(t_ptr_left, t_ptr_right, t_ptr_data, is_less); - - block -= 1; - } - - finish_up(ptr_left, ptr_right, ptr_data, is_less); - finish_down(t_ptr_left, t_ptr_right, t_ptr_data, is_less); + swap_if_less(arr_ptr, 0, 1, is_less); } } -// This implementation is only valid for Copy types. For these reasons: -// 1. Panic safety -// 2. Uniqueness preservation for types with interior mutability. -unsafe fn sort8(v: &mut [T], is_less: &mut F) +/// Sort 8 elements +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. 
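Putting both directions together, here is a self-contained safe sketch of the 8-element parity merge for Copy data, written with plain branches for clarity. The forward loop plays the role of merge_up/finish_up and the backward loop plays the role of merge_down/finish_down; each pass fills exactly half of the destination and neither cursor can run past the end of the source. The function name is illustrative, not the patch's API.

fn parity_merge8_sketch<T: Copy + Ord>(src: &[T; 8], dest: &mut [T; 8]) {
    // Illustrative sketch; the patch does the same with branchless pointer steps.
    // Forward cursors: head of each run, next low destination slot.
    let (mut l, mut r, mut d) = (0usize, 4usize, 0usize);
    // Backward cursors: tail of each run, next high destination slot.
    let (mut tl, mut tr, mut td) = (3usize, 7usize, 7usize);

    for _ in 0..3 {
        // Forward step: emit the smaller head, preferring the left run on ties.
        if src[r] < src[l] {
            dest[d] = src[r];
            r += 1;
        } else {
            dest[d] = src[l];
            l += 1;
        }
        d += 1;

        // Backward step: emit the larger tail, preferring the right run on ties.
        if src[tr] < src[tl] {
            dest[td] = src[tl];
            tl -= 1;
        } else {
            dest[td] = src[tr];
            tr -= 1;
        }
        td -= 1;
    }

    // One element of each pass remains to be placed.
    dest[d] = if src[r] < src[l] { src[r] } else { src[l] };
    dest[td] = if src[tr] < src[tl] { src[tl] } else { src[tr] };
}

// With the example traced above:
// let src = [2, 3, 4, 7, 0, 1, 6, 8];
// let mut dest = [0; 8];
// parity_merge8_sketch(&src, &mut dest);
// assert_eq!(dest, [0, 1, 2, 3, 4, 6, 7, 8]);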
+#[inline(never)] +unsafe fn sort8_stable(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - debug_assert!(v.len() == 8); - // SAFETY: caller must ensure v is at least len 8. unsafe { - sort4(v, is_less); - sort4(&mut v[4..], is_less); + debug_assert!(v.len() == 8); let arr_ptr = v.as_mut_ptr(); - if !is_less(&*arr_ptr.add(4), &*arr_ptr.add(3)) { - return; - } - let mut swap = mem::MaybeUninit::<[T; 8]>::uninit(); - let swap_ptr = swap.as_mut_ptr() as *mut T; - - // We know the two parts v[0..4] and v[4..8] are already sorted. Merge into swap_ptr, so - // that a panic of is_less leaves arr_ptr unchanged and in a valid state, preserving all - // original elements. - parity_merge8(arr_ptr, swap_ptr, is_less); - // Once the merge is done, copy everything from swap to arr. - ptr::copy_nonoverlapping(swap_ptr, arr_ptr, 8); + // Transposition sorting-network, by only comparing and swapping adjacent wires we have a stable + // sorting-network. Sorting-networks are great at leveraging Instruction-Level-Parallelism + // (ILP), they expose multiple comparisons in straight-line code with builtin data-dependency + // parallelism and ordering per layer. This has to do 28 comparisons in contrast to the 19 + // comparisons done by an optimal size 8 unstable sorting-network. + swap_next_if_less(arr_ptr.add(0), is_less); + swap_next_if_less(arr_ptr.add(2), is_less); + swap_next_if_less(arr_ptr.add(4), is_less); + swap_next_if_less(arr_ptr.add(6), is_less); + + swap_next_if_less(arr_ptr.add(1), is_less); + swap_next_if_less(arr_ptr.add(3), is_less); + swap_next_if_less(arr_ptr.add(5), is_less); + + swap_next_if_less(arr_ptr.add(0), is_less); + swap_next_if_less(arr_ptr.add(2), is_less); + swap_next_if_less(arr_ptr.add(4), is_less); + swap_next_if_less(arr_ptr.add(6), is_less); + + swap_next_if_less(arr_ptr.add(1), is_less); + swap_next_if_less(arr_ptr.add(3), is_less); + swap_next_if_less(arr_ptr.add(5), is_less); + + swap_next_if_less(arr_ptr.add(0), is_less); + swap_next_if_less(arr_ptr.add(2), is_less); + swap_next_if_less(arr_ptr.add(4), is_less); + swap_next_if_less(arr_ptr.add(6), is_less); + + swap_next_if_less(arr_ptr.add(1), is_less); + swap_next_if_less(arr_ptr.add(3), is_less); + swap_next_if_less(arr_ptr.add(5), is_less); + + swap_next_if_less(arr_ptr.add(0), is_less); + swap_next_if_less(arr_ptr.add(2), is_less); + swap_next_if_less(arr_ptr.add(4), is_less); + swap_next_if_less(arr_ptr.add(6), is_less); + + swap_next_if_less(arr_ptr.add(1), is_less); + swap_next_if_less(arr_ptr.add(3), is_less); + swap_next_if_less(arr_ptr.add(5), is_less); } } -// This implementation is only valid for Copy types. For these reasons: -// 1. Panic safety -// 2. Uniqueness preservation for types with interior mutability. -unsafe fn sort16(v: &mut [T], is_less: &mut F) +unsafe fn sort24_stable(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - debug_assert!(v.len() == 16); - - // SAFETY: caller must ensure v is at least len 16. + // SAFETY: caller must ensure v is exactly len 24. unsafe { - // Sort the 4 parts of v individually. - sort4(v, is_less); - sort4(&mut v[4..], is_less); - sort4(&mut v[8..], is_less); - sort4(&mut v[12..], is_less); - - // If all 3 pairs of border elements are sorted, we know the whole 16 elements are now sorted. - // Doing this check reduces the total comparisons done on average for different input patterns. 
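The unrolled network above is an odd-even transposition network. For reference, a generic safe sketch of the same scheme: n alternating phases of adjacent compare-swaps sort any n elements, and because only neighbours are exchanged, and only on strict less-than, equal elements keep their relative order. With n == 8 the phases alternate 4 and 3 comparisons over 8 phases, giving the 28 comparisons mentioned in the comment.

fn odd_even_transposition_sort<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], mut is_less: F) {
    // Generic sketch of the network; the patch unrolls it for exactly 8 elements
    // and uses branchless swaps instead of this branch.
    let n = v.len();
    for phase in 0..n {
        // Even phases compare pairs (0,1), (2,3), ...; odd phases (1,2), (3,4), ...
        let mut i = phase % 2;
        while i + 1 < n {
            if is_less(&v[i + 1], &v[i]) {
                v.swap(i, i + 1);
            }
            i += 2;
        }
    }
}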
- let arr_ptr = v.as_mut_ptr(); - if !is_less(&*arr_ptr.add(4), &*arr_ptr.add(3)) - && !is_less(&*arr_ptr.add(8), &*arr_ptr.add(7)) - && !is_less(&*arr_ptr.add(12), &*arr_ptr.add(11)) - { - return; - } + debug_assert!(v.len() == 24); - let mut swap = mem::MaybeUninit::<[T; 16]>::uninit(); - let swap_ptr = swap.as_mut_ptr() as *mut T; + sort8_stable(&mut v[0..8], is_less); + sort8_stable(&mut v[8..16], is_less); + sort8_stable(&mut v[16..24], is_less); - // Merge the already sorted v[0..4] with v[4..8] into swap. - parity_merge8(arr_ptr, swap_ptr, is_less); - // Merge the already sorted v[8..12] with v[12..16] into swap. - parity_merge8(arr_ptr.add(8), swap_ptr.add(8), is_less); + // We only need place for 8 entries because we know both sides are of length 8. + let mut swap = mem::MaybeUninit::<[T; 8]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; - // v is still the same as before parity_merge8 - // swap now contains a shallow copy of v and is sorted in v[0..8] and v[8..16] - // Merge the two partitions. - // parity_merge(swap_ptr, arr_ptr, 16, is_less); + // We only need place for 8 entries because we know both sides are of length 8. + merge(&mut v[..16], 8, swap_ptr, is_less); - // FIXME discuss perf loss by promising original elements in case of panic. - ptr::copy_nonoverlapping(swap_ptr, arr_ptr, 16); - parity_merge(arr_ptr, swap_ptr, 16, is_less); - ptr::copy_nonoverlapping(swap_ptr, arr_ptr, 16); + // We only need place for 8 entries because the shorter side is length 8. + merge(&mut v[..24], 16, swap_ptr, is_less); } } From bad7e053082f8a5d1df6e5b4396ce741e8734800 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Tue, 1 Nov 2022 13:42:19 +0100 Subject: [PATCH 4/5] Re-use TimSort analysis for small slices Add new loop based mini-merge sort for small sizes. This extends allocation free sorting for random inputs of types that qualify up to 32. --- library/alloc/src/slice.rs | 180 +++++++++++++++++-------------------- 1 file changed, 81 insertions(+), 99 deletions(-) diff --git a/library/alloc/src/slice.rs b/library/alloc/src/slice.rs index 882559b33fd05..57a3a3e96f36b 100644 --- a/library/alloc/src/slice.rs +++ b/library/alloc/src/slice.rs @@ -822,90 +822,6 @@ where merge_sort(v, &mut is_less); } -// Sort a small number of elements as fast as possible, without allocations. -#[cfg(not(no_global_oom_handling))] -fn stable_sort_small(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - let len = v.len(); - - // This implementation is really not fit for anything beyond that, and the call is probably a - // bug. - debug_assert!(len <= 40); - - if len < 2 { - return; - } - - // It's not clear that using custom code for specific sizes is worth it here. - // So we go with the simpler code. - let offset = if len <= 6 || !qualifies_for_branchless_sort::() { - 1 - } else { - // Once a certain threshold is reached, it becomes worth it to analyze the input and do - // branchless swapping for the first 5 elements. - - // SAFETY: We just checked that len >= 5 - unsafe { - let arr_ptr = v.as_mut_ptr(); - - let should_swap_0_1 = is_less(&*arr_ptr.add(1), &*arr_ptr.add(0)); - let should_swap_1_2 = is_less(&*arr_ptr.add(2), &*arr_ptr.add(1)); - let should_swap_2_3 = is_less(&*arr_ptr.add(3), &*arr_ptr.add(2)); - let should_swap_3_4 = is_less(&*arr_ptr.add(4), &*arr_ptr.add(3)); - - let swap_count = should_swap_0_1 as usize - + should_swap_1_2 as usize - + should_swap_2_3 as usize - + should_swap_3_4 as usize; - - if swap_count == 0 { - // Potentially already sorted. 
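The analysis being removed here counts how many of the first four adjacent pairs are descending in order to guess whether the input looks already sorted, reversed, or unordered, and picks a strategy accordingly. A safe sketch of just that probe; the enum and function names are illustrative, while the thresholds (0 or 4 descending pairs out of 4) are the ones used in the code above.

// Illustrative names only; the probe mirrors the should_swap_0_1..3_4 counting above.
enum InputGuess {
    Ascending,  // no descending pair among the first five elements
    Descending, // all four adjacent pairs descending
    Unordered,  // anything in between
}

fn guess_pattern<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> InputGuess {
    debug_assert!(v.len() >= 5);
    let descending_pairs = (0..4).filter(|&i| is_less(&v[i + 1], &v[i])).count();
    match descending_pairs {
        0 => InputGuess::Ascending,
        4 => InputGuess::Descending,
        _ => InputGuess::Unordered,
    }
}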
No need to swap, we know the first 5 elements are - // already in the right order. - 5 - } else if swap_count == 4 { - // Potentially reversed. - let mut rev_i = 4; - while rev_i < (len - 1) { - if !is_less(&*arr_ptr.add(rev_i + 1), &*arr_ptr.add(rev_i)) { - break; - } - rev_i += 1; - } - rev_i += 1; - v[..rev_i].reverse(); - insertion_sort_shift_left(v, rev_i, is_less); - return; - } else { - // Potentially random pattern. - branchless_swap(arr_ptr.add(0), arr_ptr.add(1), should_swap_0_1); - branchless_swap(arr_ptr.add(2), arr_ptr.add(3), should_swap_2_3); - - if len >= 12 { - // This aims to find a good balance between generating more code, which is bad - // for cold loops and improving hot code while not increasing mean comparison - // count too much. - sort8_stable(&mut v[4..12], is_less); - insertion_sort_shift_left(&mut v[4..], 8, is_less); - insertion_sort_shift_right(v, 4, is_less); - return; - } else { - // Complete the sort network for the first 4 elements. - swap_next_if_less(arr_ptr.add(1), is_less); - swap_next_if_less(arr_ptr.add(2), is_less); - swap_next_if_less(arr_ptr.add(0), is_less); - swap_next_if_less(arr_ptr.add(1), is_less); - - 4 - } - } - } - }; - - insertion_sort_shift_left(v, offset, is_less); -} - #[cfg(not(no_global_oom_handling))] fn merge_sort(v: &mut [T], is_less: &mut F) where @@ -918,12 +834,7 @@ where let len = v.len(); - // Slices of up to this length get sorted using insertion sort. - const MAX_NO_ALLOC_SIZE: usize = 20; - - // Short arrays get sorted in-place via insertion sort to avoid allocations. - if len <= MAX_NO_ALLOC_SIZE { - stable_sort_small(v, is_less); + if len < 2 { return; } @@ -963,6 +874,11 @@ where // return without allocating. return; } else if buf_ptr.is_null() { + // Short arrays get sorted in-place via insertion sort to avoid allocations. + if sort_small_stable(v, start, is_less) { + return; + } + // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it // shallow copies of the contents of `v` without risking the dtors running on copies if // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the @@ -1016,11 +932,7 @@ where || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) { - if n >= 3 && runs[n - 3].len < runs[n - 1].len { - Some(n - 3) - } else { - Some(n - 2) - } + if n >= 3 && runs[n - 3].len < runs[n - 1].len { Some(n - 3) } else { Some(n - 2) } } else { None } @@ -1033,6 +945,67 @@ where } } +/// Check whether `v` applies for small sort optimization. +/// `v[start..]` is assumed already sorted. +#[cfg(not(no_global_oom_handling))] +fn sort_small_stable(v: &mut [T], start: usize, is_less: &mut F) -> bool +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + if qualifies_for_branchless_sort::() { + // Testing showed that even though this incurs more comparisons, up to size 32 (4 * 8), + // avoiding the allocation and sticking with simple code is worth it. Going further eg. 40 + // is still worth it for u64 or even types with more expensive comparisons, but risks + // incurring just too many comparisons than doing the regular TimSort. + const MAX_NO_ALLOC_SIZE: usize = 32; + if len <= MAX_NO_ALLOC_SIZE { + if len < 8 { + insertion_sort_shift_right(v, start, is_less); + return true; + } + + let mut merge_count = 0; + for chunk in v.chunks_exact_mut(8) { + // SAFETY: chunks_exact_mut promised to give us slices of len 8. 
+ unsafe { + sort8_stable(chunk, is_less); + } + merge_count += 1; + } + + let mut swap = mem::MaybeUninit::<[T; 8]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + let mut i = 8; + while merge_count > 1 { + // SAFETY: We know the smaller side will be of size 8 because mid is 8. And both + // sides are non empty because of merge_count, and the right side will always be of + // size 8 and the left size of 8 or greater. Thus the smaller side will always be + // exactly 8 long, the size of swap. + unsafe { + merge(&mut v[0..(i + 8)], i, swap_ptr, is_less); + } + i += 8; + merge_count -= 1; + } + + insertion_sort_shift_left(v, i, is_less); + + return true; + } + } else { + const MAX_NO_ALLOC_SIZE: usize = 20; + if len <= MAX_NO_ALLOC_SIZE { + insertion_sort_shift_right(v, start, is_less); + return true; + } + } + + false +} + /// Takes a range as denoted by start and end, that is already sorted and extends it if necessary /// with sorts optimized for smaller ranges such as insertion sort. #[cfg(not(no_global_oom_handling))] @@ -1042,8 +1015,7 @@ where { debug_assert!(end > start); - // Testing showed that using MAX_INSERTION here yields the best performance for many types, but - // incurs more total comparisons. A balance between least comparisons and best performance, as + // This value is a balance between least comparisons and best performance, as // influenced by for example cache locality. const MIN_INSERTION_RUN: usize = 10; @@ -1115,6 +1087,7 @@ impl Drop for InsertionHole { /// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]` /// becomes sorted. +#[cfg(not(no_global_oom_handling))] unsafe fn insert_tail(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, @@ -1167,11 +1140,12 @@ where } } -/// Sort v assuming v[..offset] is already sorted. +/// Sort `v` assuming `v[..offset]` is already sorted. /// /// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no /// performance impact. Even improving performance in some cases. #[inline(never)] +#[cfg(not(no_global_oom_handling))] fn insertion_sort_shift_left(v: &mut [T], offset: usize, is_less: &mut F) where F: FnMut(&T, &T) -> bool, @@ -1195,11 +1169,12 @@ where } } -/// Sort v assuming v[offset..] is already sorted. +/// Sort `v` assuming `v[offset..]` is already sorted. /// /// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no /// performance impact. Even improving performance in some cases. #[inline(never)] +#[cfg(not(no_global_oom_handling))] fn insertion_sort_shift_right(v: &mut [T], offset: usize, is_less: &mut F) where F: FnMut(&T, &T) -> bool, @@ -1227,6 +1202,7 @@ where /// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted. /// /// This is the integral subroutine of insertion sort. +#[cfg(not(no_global_oom_handling))] unsafe fn insert_head(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, @@ -1287,6 +1263,10 @@ where /// /// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough /// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. +#[inline(never)] #[cfg(not(no_global_oom_handling))] unsafe fn merge(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F) where @@ -1506,6 +1486,7 @@ where /// Never inline this function to avoid code bloat. 
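The new sort_small_stable above sorts each full 8-element chunk on its own and then folds the runs together left to right, which is why an 8-element scratch buffer always suffices. A safe, self-contained sketch of that shape for Copy data; chunk.sort() and the branching backward merge are stand-ins for the patch's sort8_stable and merge, and the function name is illustrative.

fn small_sort_by_chunks<T: Copy + Ord>(v: &mut [T]) {
    // Illustrative sketch of the chunk-then-merge strategy, not the patch's code.
    const CHUNK: usize = 8;
    let full = (v.len() / CHUNK) * CHUNK;

    // Sort each full chunk on its own; stand-in for the branchless sort8_stable.
    for chunk in v[..full].chunks_exact_mut(CHUNK) {
        chunk.sort();
    }

    // Merge the sorted runs left to right. The right-hand run is always exactly
    // CHUNK elements long, so a fixed CHUNK-sized scratch buffer is enough.
    let mut sorted = CHUNK;
    while sorted < full {
        let end = sorted + CHUNK;
        let mut buf = [v[sorted]; CHUNK];
        buf.copy_from_slice(&v[sorted..end]);

        // Merge backwards into v[..end]: take the larger tail each step and
        // prefer the buffered (right) run on ties, which keeps the merge stable.
        let (mut i, mut j, mut k) = (sorted, CHUNK, end);
        while i > 0 && j > 0 {
            k -= 1;
            if buf[j - 1] < v[i - 1] {
                v[k] = v[i - 1];
                i -= 1;
            } else {
                v[k] = buf[j - 1];
                j -= 1;
            }
        }
        // Whatever is left of the buffered run belongs at the front of the gap.
        v[k - j..k].copy_from_slice(&buf[..j]);
        sorted = end;
    }

    // Any tail shorter than CHUNK (or a slice shorter than CHUNK overall) is
    // folded in by insertion, the role insertion_sort_shift_left plays in the patch.
    for idx in full..v.len() {
        let mut j = idx;
        while j > 0 && v[j] < v[j - 1] {
            v.swap(j, j - 1);
            j -= 1;
        }
    }
}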
It still optimizes nicely and has practically no /// performance impact. #[inline(never)] +#[cfg(not(no_global_oom_handling))] unsafe fn sort8_stable(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, @@ -1559,6 +1540,7 @@ where } } +#[cfg(not(no_global_oom_handling))] unsafe fn sort24_stable(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, From 061d4e80386fd4f310d4f55f99860eb533374c01 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Fri, 4 Nov 2022 16:54:58 +0100 Subject: [PATCH 5/5] Apply cfg(not(no_global_oom_handling) to everything --- library/alloc/src/slice.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/library/alloc/src/slice.rs b/library/alloc/src/slice.rs index 57a3a3e96f36b..a15062f4b6d37 100644 --- a/library/alloc/src/slice.rs +++ b/library/alloc/src/slice.rs @@ -1072,11 +1072,13 @@ where } // When dropped, copies from `src` into `dest`. +#[cfg(not(no_global_oom_handling))] struct InsertionHole { src: *const T, dest: *mut T, } +#[cfg(not(no_global_oom_handling))] impl Drop for InsertionHole { fn drop(&mut self) { unsafe { @@ -1398,6 +1400,7 @@ impl IsCopy for T { } #[inline] +#[cfg(not(no_global_oom_handling))] fn qualifies_for_branchless_sort() -> bool { // This is a heuristic, and as such it will guess wrong from time to time. The two parts broken // down: @@ -1416,6 +1419,7 @@ fn qualifies_for_branchless_sort() -> bool { /// Swap two values in array pointed to by a_ptr and b_ptr if b is less than a. #[inline] +#[cfg(not(no_global_oom_handling))] unsafe fn branchless_swap(a_ptr: *mut T, b_ptr: *mut T, should_swap: bool) { // This is a branchless version of swap if. // The equivalent code with a branch would be: @@ -1443,6 +1447,7 @@ unsafe fn branchless_swap(a_ptr: *mut T, b_ptr: *mut T, should_swap: bool) { /// Swap two values in array pointed to by a_ptr and b_ptr if b is less than a. #[inline] +#[cfg(not(no_global_oom_handling))] unsafe fn swap_if_less(arr_ptr: *mut T, a: usize, b: usize, is_less: &mut F) where F: FnMut(&T, &T) -> bool, @@ -1470,6 +1475,7 @@ where /// Comparing and swapping anything but adjacent elements will yield a non stable sort. /// So this must be fundamental building block for stable sorting networks. #[inline] +#[cfg(not(no_global_oom_handling))] unsafe fn swap_next_if_less(arr_ptr: *mut T, is_less: &mut F) where F: FnMut(&T, &T) -> bool,
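The body of qualifies_for_branchless_sort is not shown in these hunks; its doc comment only states that the heuristic has two parts. Below is a plausible sketch, assuming (based on the IsCopy specialization in the surrounding context and the Copy-only callers) that the two parts are a plain-data check and a size cut-off. The Copy bound stands in for the runtime IsCopy check and the threshold is an assumption, not the patch's value.

use core::mem;

// Assumed sketch only: the real heuristic reportedly has two parts, but its
// exact conditions and threshold are not visible in this diff.
fn qualifies_for_branchless_sort_sketch<T: Copy>() -> bool {
    // 1) Plain-data types: unconditional shallow copies are sound and cheap,
    //    with no interior mutability or Drop to worry about.
    // 2) Small types: the branchless swap copies the value three times, which
    //    only beats a predicted branch while T is at most a few machine words.
    mem::size_of::<T>() > 0 && mem::size_of::<T>() <= 4 * mem::size_of::<usize>()
}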