
Commit 468babb

Redo the swap code for better tail & padding handling
1 parent 4e5fec2 commit 468babb

12 files changed: +420 -134 lines changed (8 of the 12 files are shown below)

compiler/rustc_codegen_llvm/src/intrinsic.rs (+17)

@@ -498,6 +498,23 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                 }
             }

+            sym::untyped_swap_nonoverlapping => {
+                // The fallback impl uses memcpy, which leaves around allocas
+                // that don't optimize out for certain widths, so force it to
+                // use SSA registers instead.
+
+                let chunk_ty = fn_args.type_at(0);
+                let layout = self.layout_of(chunk_ty).layout;
+                let integer_ty = self.type_ix(layout.size().bits());
+                let a = args[0].immediate();
+                let b = args[1].immediate();
+                let a_val = self.load(integer_ty, a, layout.align().abi);
+                let b_val = self.load(integer_ty, b, layout.align().abi);
+                self.store(b_val, a, layout.align().abi);
+                self.store(a_val, b, layout.align().abi);
+                return Ok(());
+            }
+
             sym::compare_bytes => {
                 // Here we assume that the `memcmp` provided by the target is a NOP for size 0.
                 let cmp = self.call_intrinsic("memcmp", &[

compiler/rustc_hir_analysis/src/check/intrinsic.rs (+6)

@@ -504,6 +504,12 @@ pub fn check_intrinsic_type(
         sym::typed_swap_nonoverlapping => {
             (1, 0, vec![Ty::new_mut_ptr(tcx, param(0)); 2], tcx.types.unit)
         }
+        sym::untyped_swap_nonoverlapping => (
+            1,
+            0,
+            vec![Ty::new_mut_ptr(tcx, Ty::new_maybe_uninit(tcx, param(0))); 2],
+            tcx.types.unit,
+        ),

         sym::discriminant_value => {
             let assoc_items = tcx.associated_item_def_ids(

compiler/rustc_span/src/symbol.rs (+1)

@@ -2142,6 +2142,7 @@ symbols! {
         unstable location; did you mean to load this crate \
         from crates.io via `Cargo.toml` instead?",
         untagged_unions,
+        untyped_swap_nonoverlapping,
         unused_imports,
         unwind,
         unwind_attributes,

library/core/src/intrinsics/mod.rs (+32 -2)

@@ -66,7 +66,7 @@

 use crate::marker::{DiscriminantKind, Tuple};
 use crate::mem::SizedTypeProperties;
-use crate::{ptr, ub_checks};
+use crate::{mem, ptr, ub_checks};

 pub mod fallback;
 pub mod mir;

@@ -4003,7 +4003,37 @@ pub use typed_swap as typed_swap_nonoverlapping;
 pub const unsafe fn typed_swap_nonoverlapping<T>(x: *mut T, y: *mut T) {
     // SAFETY: The caller provided single non-overlapping items behind
     // pointers, so swapping them with `count: 1` is fine.
-    unsafe { ptr::swap_nonoverlapping(x, y, 1) };
+    unsafe { crate::swapping::swap_nonoverlapping(x, y, 1) };
+}
+
+/// Swaps the `N` untyped & non-overlapping bytes behind the two pointers.
+///
+/// Split out from `typed_swap` for the internal swaps in `swap_nonoverlapping`
+/// which would otherwise cause cycles between the fallback implementations on
+/// backends where neither is overridden.
+///
+/// # Safety
+///
+/// `x` and `y` are readable and writable as `MaybeUninit<C>` and non-overlapping.
+#[inline]
+#[rustc_nounwind]
+#[cfg_attr(not(bootstrap), rustc_intrinsic)]
+#[miri::intrinsic_fallback_is_spec]
+#[rustc_const_stable_indirect]
+pub const unsafe fn untyped_swap_nonoverlapping<C>(
+    x: *mut mem::MaybeUninit<C>,
+    y: *mut mem::MaybeUninit<C>,
+) {
+    // This intentionally uses untyped memory copies, not reads/writes,
+    // to avoid any risk of losing padding in things like (u16, u8).
+    let mut temp = mem::MaybeUninit::<C>::uninit();
+    // SAFETY: Caller promised that x and y are non-overlapping & read/writeable,
+    // and our fresh local is always disjoint from anything otherwise readable.
+    unsafe {
+        (&raw mut temp).copy_from_nonoverlapping(x, 1);
+        x.copy_from_nonoverlapping(y, 1);
+        y.copy_from_nonoverlapping(&raw const temp, 1);
+    }
 }

 /// Returns whether we should perform some UB-checking at runtime. This eventually evaluates to

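The fallback body above is the whole trick: swap through a `MaybeUninit` temporary using untyped, byte-wise copies, so padding bytes (e.g. the trailing byte of `(u16, u8)`) travel along instead of being dropped by a typed read. A minimal standalone sketch of the same pattern on stable APIs (`swap_one` and the `u32` demo are invented for illustration, not part of the commit):

use std::mem::MaybeUninit;

/// Hypothetical helper mirroring the intrinsic fallback: swap one `T`-sized
/// block through a `MaybeUninit` temporary with three untyped copies.
unsafe fn swap_one<T>(x: *mut MaybeUninit<T>, y: *mut MaybeUninit<T>) {
    let mut temp = MaybeUninit::<T>::uninit();
    // SAFETY (per this sketch's contract): `x` and `y` are valid and
    // non-overlapping, and `temp` is a fresh local disjoint from both.
    unsafe {
        (&raw mut temp).copy_from_nonoverlapping(x, 1);
        x.copy_from_nonoverlapping(y, 1);
        y.copy_from_nonoverlapping(&raw const temp, 1);
    }
}

fn main() {
    let mut a = MaybeUninit::new(1u32);
    let mut b = MaybeUninit::new(2u32);
    // SAFETY: `a` and `b` are distinct locals, so the pointers cannot overlap.
    unsafe { swap_one(&raw mut a, &raw mut b) };
    assert_eq!(unsafe { a.assume_init() }, 2);
    assert_eq!(unsafe { b.assume_init() }, 1);
}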

library/core/src/lib.rs (+1)

@@ -376,6 +376,7 @@ pub mod alloc;
 // note: does not need to be public
 mod bool;
 mod escape;
+pub(crate) mod swapping;
 mod tuple;
 mod unit;

library/core/src/ptr/mod.rs (+1 -78)

@@ -395,7 +395,6 @@
 #![allow(clippy::not_unsafe_ptr_arg_deref)]

 use crate::cmp::Ordering;
-use crate::intrinsics::const_eval_select;
 use crate::marker::FnPtr;
 use crate::mem::{self, MaybeUninit, SizedTypeProperties};
 use crate::{fmt, hash, intrinsics, ub_checks};

@@ -1092,84 +1091,8 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
         }
     );

-    const_eval_select!(
-        @capture[T] { x: *mut T, y: *mut T, count: usize }:
-        if const {
-            // At compile-time we want to always copy this in chunks of `T`, to ensure that if there
-            // are pointers inside `T` we will copy them in one go rather than trying to copy a part
-            // of a pointer (which would not work).
     // SAFETY: Same preconditions as this function
-            unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
-        } else {
-            macro_rules! attempt_swap_as_chunks {
-                ($ChunkTy:ty) => {
-                    if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
-                        && mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
-                    {
-                        let x: *mut $ChunkTy = x.cast();
-                        let y: *mut $ChunkTy = y.cast();
-                        let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
-                        // SAFETY: these are the same bytes that the caller promised were
-                        // ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
-                        // The `if` condition above ensures that we're not violating
-                        // alignment requirements, and that the division is exact so
-                        // that we don't lose any bytes off the end.
-                        return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
-                    }
-                };
-            }
-
-            // Split up the slice into small power-of-two-sized chunks that LLVM is able
-            // to vectorize (unless it's a special type with more-than-pointer alignment,
-            // because we don't want to pessimize things like slices of SIMD vectors.)
-            if mem::align_of::<T>() <= mem::size_of::<usize>()
-                && (!mem::size_of::<T>().is_power_of_two()
-                    || mem::size_of::<T>() > mem::size_of::<usize>() * 2)
-            {
-                attempt_swap_as_chunks!(usize);
-                attempt_swap_as_chunks!(u8);
-            }
-
-            // SAFETY: Same preconditions as this function
-            unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
-        }
-    )
-}
-
-/// Same behavior and safety conditions as [`swap_nonoverlapping`]
-///
-/// LLVM can vectorize this (at least it can for the power-of-two-sized types
-/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
-#[inline]
-const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, count: usize) {
-    let x = x.cast::<MaybeUninit<T>>();
-    let y = y.cast::<MaybeUninit<T>>();
-    let mut i = 0;
-    while i < count {
-        // SAFETY: By precondition, `i` is in-bounds because it's below `n`
-        let x = unsafe { x.add(i) };
-        // SAFETY: By precondition, `i` is in-bounds because it's below `n`
-        // and it's distinct from `x` since the ranges are non-overlapping
-        let y = unsafe { y.add(i) };
-
-        // If we end up here, it's because we're using a simple type -- like
-        // a small power-of-two-sized thing -- or a special type with particularly
-        // large alignment, particularly SIMD types.
-        // Thus, we're fine just reading-and-writing it, as either it's small
-        // and that works well anyway or it's special and the type's author
-        // presumably wanted things to be done in the larger chunk.
-
-        // SAFETY: we're only ever given pointers that are valid to read/write,
-        // including being aligned, and nothing here panics so it's drop-safe.
-        unsafe {
-            let a: MaybeUninit<T> = read(x);
-            let b: MaybeUninit<T> = read(y);
-            write(x, b);
-            write(y, a);
-        }
-
-        i += 1;
-    }
+    unsafe { crate::swapping::swap_nonoverlapping(x, y, count) }
 }

 /// Moves `src` into the pointed `dst`, returning the previous `dst` value.

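Since only the implementation moved, the public entry point is unchanged; a minimal usage sketch of the stable `ptr::swap_nonoverlapping` API that now routes through `core::swapping` (example invented here, not part of the commit):

use std::ptr;

fn main() {
    // Swap the two non-overlapping halves of a buffer.
    let mut buf = [1u8, 2, 3, 4, 5, 6];
    let (a, b) = buf.split_at_mut(3);
    // SAFETY: the halves come from `split_at_mut`, so they cannot overlap,
    // and each is valid for reads and writes of 3 `u8`s.
    unsafe { ptr::swap_nonoverlapping(a.as_mut_ptr(), b.as_mut_ptr(), 3) };
    assert_eq!(buf, [4, 5, 6, 1, 2, 3]);
}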

library/core/src/swapping.rs (new file, +160)

@@ -0,0 +1,160 @@
+use crate::{hint, intrinsics, mem, ptr};
+
+//#[rustc_const_stable_indirect]
+//#[rustc_allow_const_fn_unstable(const_eval_select)]
+#[rustc_const_unstable(feature = "const_swap_nonoverlapping", issue = "133668")]
+#[inline]
+pub(crate) const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
+    intrinsics::const_eval_select!(
+        @capture[T] { x: *mut T, y: *mut T, count: usize }:
+        if const {
+            // At compile-time we want to always copy this in chunks of `T`, to ensure that if there
+            // are pointers inside `T` we will copy them in one go rather than trying to copy a part
+            // of a pointer (which would not work).
+            // SAFETY: Same preconditions as this function
+            unsafe { swap_nonoverlapping_const(x, y, count) }
+        } else {
+            // At runtime we want to make sure not to swap byte-for-byte for types like [u8; 15],
+            // and swapping as `MaybeUninit<T>` doesn't actually work as untyped for things like
+            // T = (u16, u8), so we type-erase to raw bytes and swap that way.
+            // SAFETY: Same preconditions as this function
+            unsafe { swap_nonoverlapping_runtime(x, y, count) }
+        }
+    )
+}
+
+/// Same behavior and safety conditions as [`swap_nonoverlapping`]
+#[rustc_const_stable_indirect]
+#[inline]
+const unsafe fn swap_nonoverlapping_const<T>(x: *mut T, y: *mut T, count: usize) {
+    let x = x.cast::<mem::MaybeUninit<T>>();
+    let y = y.cast::<mem::MaybeUninit<T>>();
+    let mut i = 0;
+    while i < count {
+        // SAFETY: By precondition, `i` is in-bounds because it's below `n`
+        // and because the two input ranges are non-overlapping and read/writeable,
+        // these individual items inside them are too.
+        unsafe {
+            intrinsics::untyped_swap_nonoverlapping::<T>(x.add(i), y.add(i));
+        }
+
+        i += 1;
+    }
+}
+
+// Scale the monomorphizations with the size of the machine, roughly.
+const MAX_ALIGN: usize = align_of::<usize>().pow(2);
+
+/// Same behavior and safety conditions as [`swap_nonoverlapping`]
+#[inline]
+unsafe fn swap_nonoverlapping_runtime<T>(x: *mut T, y: *mut T, count: usize) {
+    let bytes = {
+        let slice = ptr::slice_from_raw_parts(x, count);
+        unsafe { mem::size_of_val_raw(slice) }
+    };
+
+    // Generating *untyped* loops for every type is silly, so we polymorphize away
+    // the actual type, but we want to take advantage of alignment if possible,
+    // so monomorphize for a restricted set of possible alignments.
+    macro_rules! delegate_by_alignment {
+        ($($p:pat => $align:expr,)+) => {{
+            #![allow(unreachable_patterns)]
+            match const { align_of::<T>() } {
+                $(
+                    $p => {
+                        swap_nonoverlapping_bytes::<$align>(x.cast(), y.cast(), bytes);
+                    }
+                )+
+            }
+        }};
+    }
+
+    // SAFETY:
+    unsafe {
+        delegate_by_alignment! {
+            MAX_ALIGN.. => MAX_ALIGN,
+            64.. => 64,
+            32.. => 32,
+            16.. => 16,
+            8.. => 8,
+            4.. => 4,
+            2.. => 2,
+            _ => 1,
+        }
+    }
+}
+
+/// # Safety:
+/// - `x` and `y` must be aligned to `ALIGN`
+/// - `bytes` must be a multiple of `ALIGN`
+/// - They must be readable, writable, and non-overlapping for `bytes` bytes
+#[inline]
+unsafe fn swap_nonoverlapping_bytes<const ALIGN: usize>(
+    x: *mut mem::MaybeUninit<u8>,
+    y: *mut mem::MaybeUninit<u8>,
+    bytes: usize,
+) {
+    // SAFETY: Two legal non-overlapping regions can't be bigger than this.
+    // (And they couldn't have made allocations any bigger either anyway.)
+    // FIXME: Would be nice to have a type for this instead of the assume.
+    unsafe { hint::assert_unchecked(bytes < isize::MAX as usize) };
+
+    let mut i = 0;
+    macro_rules! swap_next_n {
+        ($n:expr) => {{
+            let x: *mut mem::MaybeUninit<[u8; $n]> = x.add(i).cast();
+            let y: *mut mem::MaybeUninit<[u8; $n]> = y.add(i).cast();
+            swap_nonoverlapping_aligned_chunk::<ALIGN, [u8; $n]>(
+                x.as_mut_unchecked(),
+                y.as_mut_unchecked(),
+            );
+            i += $n;
+        }};
+    }
+
+    while bytes - i >= MAX_ALIGN {
+        unsafe {
+            swap_next_n!(MAX_ALIGN);
+        }
+    }
+
+    macro_rules! handle_tail {
+        ($($n:literal)+) => {$(
+            if const { $n % ALIGN == 0 } {
+                // Checking this way simplifies the block end to just add+test,
+                // rather than needing extra math before the check.
+                if (bytes & $n) != 0 {
+                    unsafe {
+                        swap_next_n!($n);
+                    }
+                }
+            }
+        )+};
+    }
+    const { assert!(MAX_ALIGN <= 64) };
+    handle_tail!(32 16 8 4 2 1);
+
+    debug_assert_eq!(i, bytes);
+}
+
+// Don't let MIR inline this, because we really want it to keep its noalias metadata
+#[rustc_no_mir_inline]
+#[inline]
+unsafe fn swap_nonoverlapping_aligned_chunk<const ALIGN: usize, C: Copy>(
+    x: &mut mem::MaybeUninit<C>,
+    y: &mut mem::MaybeUninit<C>,
+) {
+    assert!(size_of::<C>() % ALIGN == 0);
+
+    let x = ptr::from_mut(x);
+    let y = ptr::from_mut(y);
+
+    unsafe {
+        hint::assert_unchecked(x.is_aligned_to(ALIGN));
+        hint::assert_unchecked(y.is_aligned_to(ALIGN));
+    }
+
+    unsafe {
+        intrinsics::untyped_swap_nonoverlapping::<C>(x, y);
+    }
+}

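To make the tail handling above concrete, here is a small standalone sketch of how the `while` loop plus `handle_tail!` decompose a byte count into swap chunk sizes (not part of the commit; `chunk_sizes` is an invented helper and `MAX_ALIGN` is hard-coded to 64, i.e. `align_of::<usize>().pow(2)` on a 64-bit target):

/// Hypothetical helper mirroring `swap_nonoverlapping_bytes`: whole
/// `MAX_ALIGN`-sized chunks first, then power-of-two tail pieces whose
/// sizes remain multiples of the type's alignment.
fn chunk_sizes(bytes: usize, align: usize) -> Vec<usize> {
    const MAX_ALIGN: usize = 64;
    let mut sizes = Vec::new();
    let mut i = 0;
    while bytes - i >= MAX_ALIGN {
        sizes.push(MAX_ALIGN);
        i += MAX_ALIGN;
    }
    for n in [32usize, 16, 8, 4, 2, 1] {
        // `n % align == 0` keeps every tail piece aligned, and `bytes & n` tests
        // the matching bit of the remainder the loop left behind (which equals
        // `bytes % 64`, so the low bits of `bytes` describe it exactly).
        if n % align == 0 && (bytes & n) != 0 {
            sizes.push(n);
            i += n;
        }
    }
    debug_assert_eq!(i, bytes);
    sizes
}

fn main() {
    // 15 bytes of a 1-aligned type: 8 + 4 + 2 + 1, never 15 byte-sized swaps.
    assert_eq!(chunk_sizes(15, 1), vec![8, 4, 2, 1]);
    // 24 bytes of an 8-aligned type: 16 + 8, so every piece stays 8-aligned.
    assert_eq!(chunk_sizes(24, 8), vec![16, 8]);
    // 100 bytes of a 4-aligned type: one full 64-byte chunk, then 32 + 4.
    assert_eq!(chunk_sizes(100, 4), vec![64, 32, 4]);
}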

library/core/tests/ptr.rs (+29)

@@ -992,3 +992,32 @@ fn test_ptr_metadata_in_const() {
     assert_eq!(SLICE_META, 3);
     assert_eq!(DYN_META.size_of(), 42);
 }
+
+// See <https://github.com/rust-lang/rust/issues/134713>
+#[test]
+fn test_ptr_swap_nonoverlapping_swaps_padding() {
+    #[repr(C)]
+    struct Foo(usize, u8);
+
+    let buf1: [usize; 2] = [1000, 2000];
+    let buf2: [usize; 2] = [3000, 4000];
+
+    // Foo and [usize; 2] have the same size and alignment,
+    // so swap_nonoverlapping should treat them the same
+    assert_eq!(size_of::<Foo>(), size_of::<[usize; 2]>());
+    assert_eq!(align_of::<Foo>(), align_of::<[usize; 2]>());
+
+    let mut b1 = buf1;
+    let mut b2 = buf2;
+    // Safety: b1 and b2 are distinct local variables,
+    // with the same size and alignment as Foo.
+    unsafe {
+        std::ptr::swap_nonoverlapping(
+            b1.as_mut_ptr().cast::<Foo>(),
+            b2.as_mut_ptr().cast::<Foo>(),
+            1,
+        );
+    }
+    assert_eq!(b1, buf2);
+    assert_eq!(b2, buf1);
+}
