Skip to content

Commit 59d2b10

Browse files
authored
Rollup merge of #136690 - Voultapher:use-more-explicit-and-reliable-ptr-select, r=thomcc
Use more explicit and reliable ptr select in sort impls. Using `if ...` with the intent to avoid branches can be surprising to readers and carries the risk of turning into jumps/branches generated by some future compiler version, breaking crucial optimizations. This commit replaces their usage with the explicit and IR-annotated `bool::select_unpredictable`.
2 parents 8227910 + 4c9b9d7 commit 59d2b10

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

library/core/src/slice/sort/shared/smallsort.rs

+13-18
Original file line number | Diff line number | Diff line change
@@ -387,7 +387,7 @@ unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less
387387
where
388388
F: FnMut(&T, &T) -> bool,
389389
{
390-
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
390+
// SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid
391391
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
392392
unsafe {
393393
let v_a = v_base.add(a_pos);
@@ -404,16 +404,16 @@ where
404404
// The equivalent code with a branch would be:
405405
//
406406
// if should_swap {
407-
// ptr::swap(left, right, 1);
407+
// ptr::swap(v_a, v_b, 1);
408408
// }
409409

410410
// The goal is to generate cmov instructions here.
411-
let left_swap = if should_swap { v_b } else { v_a };
412-
let right_swap = if should_swap { v_a } else { v_b };
411+
let v_a_swap = should_swap.select_unpredictable(v_b, v_a);
412+
let v_b_swap = should_swap.select_unpredictable(v_a, v_b);
413413

414-
let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
415-
ptr::copy(left_swap, v_a, 1);
416-
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
414+
let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap));
415+
ptr::copy(v_a_swap, v_a, 1);
416+
ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1);
417417
}
418418
}
419419

@@ -640,26 +640,21 @@ pub unsafe fn sort4_stable<T, F: FnMut(&T, &T) -> bool>(
640640
// 1, 1 | c b a d
641641
let c3 = is_less(&*c, &*a);
642642
let c4 = is_less(&*d, &*b);
643-
let min = select(c3, c, a);
644-
let max = select(c4, b, d);
645-
let unknown_left = select(c3, a, select(c4, c, b));
646-
let unknown_right = select(c4, d, select(c3, b, c));
643+
let min = c3.select_unpredictable(c, a);
644+
let max = c4.select_unpredictable(b, d);
645+
let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b));
646+
let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c));
647647

648648
// Sort the last two unknown elements.
649649
let c5 = is_less(&*unknown_right, &*unknown_left);
650-
let lo = select(c5, unknown_right, unknown_left);
651-
let hi = select(c5, unknown_left, unknown_right);
650+
let lo = c5.select_unpredictable(unknown_right, unknown_left);
651+
let hi = c5.select_unpredictable(unknown_left, unknown_right);
652652

653653
ptr::copy_nonoverlapping(min, dst, 1);
654654
ptr::copy_nonoverlapping(lo, dst.add(1), 1);
655655
ptr::copy_nonoverlapping(hi, dst.add(2), 1);
656656
ptr::copy_nonoverlapping(max, dst.add(3), 1);
657657
}
658-
659-
#[inline(always)]
660-
fn select<T>(cond: bool, if_true: *const T, if_false: *const T) -> *const T {
661-
if cond { if_true } else { if_false }
662-
}
663658
}
664659

665660
/// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and

0 commit comments

Comments
 (0)