
Commit 6964320

Ensure swap_nonoverlapping is really always untyped

1 parent a18bd8a

6 files changed: +182 -79 lines changed

library/core/src/ptr/mod.rs
+75 -42

@@ -398,6 +398,7 @@ use crate::cmp::Ordering;
 use crate::intrinsics::const_eval_select;
 use crate::marker::FnPtr;
 use crate::mem::{self, MaybeUninit, SizedTypeProperties};
+use crate::num::NonZero;
 use crate::{fmt, hash, intrinsics, ub_checks};
 
 mod alignment;
@@ -1094,49 +1095,22 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
             // are pointers inside `T` we will copy them in one go rather than trying to copy a part
             // of a pointer (which would not work).
             // SAFETY: Same preconditions as this function
-            unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
+            unsafe { swap_nonoverlapping_const(x, y, count) }
         } else {
-            macro_rules! attempt_swap_as_chunks {
-                ($ChunkTy:ty) => {
-                    if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
-                        && mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
-                    {
-                        let x: *mut $ChunkTy = x.cast();
-                        let y: *mut $ChunkTy = y.cast();
-                        let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
-                        // SAFETY: these are the same bytes that the caller promised were
-                        // ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
-                        // The `if` condition above ensures that we're not violating
-                        // alignment requirements, and that the division is exact so
-                        // that we don't lose any bytes off the end.
-                        return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
-                    }
-                };
+            // SAFETY: To exist as a memory range its size in bytes can't overflow.
+            let bytes = unsafe { size_of::<T>().unchecked_mul(count) };
+            if let Some(bytes) = NonZero::new(bytes) {
+                // SAFETY: These are the same ranges, just expressed in a different
+                // type, so they're still non-overlapping.
+                unsafe { swap_nonoverlapping_bytes(x.cast(), y.cast(), bytes) };
             }
-
-            // Split up the slice into small power-of-two-sized chunks that LLVM is able
-            // to vectorize (unless it's a special type with more-than-pointer alignment,
-            // because we don't want to pessimize things like slices of SIMD vectors.)
-            if mem::align_of::<T>() <= mem::size_of::<usize>()
-                && (!mem::size_of::<T>().is_power_of_two()
-                    || mem::size_of::<T>() > mem::size_of::<usize>() * 2)
-            {
-                attempt_swap_as_chunks!(usize);
-                attempt_swap_as_chunks!(u8);
-            }
-
-            // SAFETY: Same preconditions as this function
-            unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
         }
     )
 }
 
 /// Same behavior and safety conditions as [`swap_nonoverlapping`]
-///
-/// LLVM can vectorize this (at least it can for the power-of-two-sized types
-/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
 #[inline]
-const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, count: usize) {
+const unsafe fn swap_nonoverlapping_const<T>(x: *mut T, y: *mut T, count: usize) {
     let x = x.cast::<MaybeUninit<T>>();
     let y = y.cast::<MaybeUninit<T>>();
     let mut i = 0;
@@ -1147,13 +1121,6 @@ const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, coun
         // and it's distinct from `x` since the ranges are non-overlapping
         let y = unsafe { y.add(i) };
 
-        // If we end up here, it's because we're using a simple type -- like
-        // a small power-of-two-sized thing -- or a special type with particularly
-        // large alignment, particularly SIMD types.
-        // Thus, we're fine just reading-and-writing it, as either it's small
-        // and that works well anyway or it's special and the type's author
-        // presumably wanted things to be done in the larger chunk.
-
         // SAFETY: we're only ever given pointers that are valid to read/write,
         // including being aligned, and nothing here panics so it's drop-safe.
         unsafe {
@@ -1167,6 +1134,72 @@ const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, coun
     }
 }
 
+// Don't let MIR inline this, because we really want it to keep its noalias metadata
+#[rustc_no_mir_inline]
+#[inline]
+fn swap_chunk<const N: usize>(x: &mut MaybeUninit<[u8; N]>, y: &mut MaybeUninit<[u8; N]>) {
+    let a = *x;
+    let b = *y;
+    *x = b;
+    *y = a;
+}
+
+#[inline]
+unsafe fn swap_nonoverlapping_bytes(x: *mut u8, y: *mut u8, bytes: NonZero<usize>) {
+    // Same as `swap_nonoverlapping::<[u8; N]>`.
+    #[inline]
+    unsafe fn swap_nonoverlapping_chunks<const N: usize>(
+        x: *mut MaybeUninit<[u8; N]>,
+        y: *mut MaybeUninit<[u8; N]>,
+        chunks: NonZero<usize>,
+    ) {
+        let chunks = chunks.get();
+        for i in 0..chunks {
+            // SAFETY: i is in [0, chunks) so the adds and dereferences are in-bounds.
+            unsafe { swap_chunk(&mut *x.add(i), &mut *y.add(i)) };
+        }
+    }
+
+    // Same as `swap_nonoverlapping_bytes`, but accepts at most 1+2+4=7 bytes
+    #[inline]
+    unsafe fn swap_nonoverlapping_short(x: *mut u8, y: *mut u8, bytes: NonZero<usize>) {
+        let bytes = bytes.get();
+        let mut i = 0;
+        macro_rules! swap_prefix {
+            ($($n:literal)+) => {$(
+                if (bytes & $n) != 0 {
+                    // SAFETY: `i` can only have the same bits set as those in bytes,
+                    // so these `add`s are in-bounds of `bytes`. But the bit for
+                    // `$n` hasn't been set yet, so the `$n` bytes that `swap_chunk`
+                    // will read and write are within the usable range.
+                    unsafe { swap_chunk::<$n>(&mut*x.add(i).cast(), &mut*y.add(i).cast()) };
+                    i |= $n;
+                }
+            )+};
+        }
+        swap_prefix!(4 2 1);
+        debug_assert_eq!(i, bytes);
+    }
+
+    const CHUNK_SIZE: usize = size_of::<*const ()>();
+    let bytes = bytes.get();
+
+    let chunks = bytes / CHUNK_SIZE;
+    let tail = bytes % CHUNK_SIZE;
+    if let Some(chunks) = NonZero::new(chunks) {
+        // SAFETY: this is bytes/CHUNK_SIZE*CHUNK_SIZE bytes, which is <= bytes,
+        // so it's within the range of our non-overlapping bytes.
+        unsafe { swap_nonoverlapping_chunks::<CHUNK_SIZE>(x.cast(), y.cast(), chunks) };
+    }
+    if let Some(tail) = NonZero::new(tail) {
+        const { assert!(CHUNK_SIZE <= 8) };
+        let delta = chunks * CHUNK_SIZE;
+        // SAFETY: the tail length is below CHUNK SIZE because of the remainder,
+        // and CHUNK_SIZE is at most 8 by the const assert, so tail <= 7
+        unsafe { swap_nonoverlapping_short(x.add(delta), y.add(delta), tail) };
+    }
+}
+
 /// Moves `src` into the pointed `dst`, returning the previous `dst` value.
 ///
 /// Neither value is dropped.
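
For context, a minimal usage sketch of the public `std::ptr::swap_nonoverlapping` API whose internals change above (my own example, not part of the commit). With a 6-byte swap like this one (three u16s), a 64-bit build now goes entirely through the new byte-based path: zero pointer-sized chunks, then a 4-byte and a 2-byte tail swap.

use std::ptr;

fn main() {
    let mut a = [1u16, 2, 3]; // 3 x u16 = 6 bytes total
    let mut b = [9u16, 8, 7];
    // SAFETY: the two arrays are distinct locals, so the ranges cannot overlap,
    // and both pointers are valid and aligned for 3 elements.
    unsafe { ptr::swap_nonoverlapping(a.as_mut_ptr(), b.as_mut_ptr(), 3) };
    assert_eq!(a, [9, 8, 7]);
    assert_eq!(b, [1, 2, 3]);
}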

tests/assembly/x86_64-typed-swap.rs
+28

@@ -51,3 +51,31 @@ pub fn swap_simd(x: &mut __m128, y: &mut __m128) {
     // CHECK-NEXT: retq
     swap(x, y)
 }
+
+// CHECK-LABEL: swap_string:
+#[no_mangle]
+pub fn swap_string(x: &mut String, y: &mut String) {
+    // CHECK-NOT: mov
+    // CHECK-COUNT-4: movups
+    // CHECK-NOT: mov
+    // CHECK-COUNT-4: movq
+    // CHECK-NOT: mov
+    swap(x, y)
+}
+
+// CHECK-LABEL: swap_44_bytes:
+#[no_mangle]
+pub fn swap_44_bytes(x: &mut [u8; 44], y: &mut [u8; 44]) {
+    // Ensure we do better than a long run of byte copies,
+    // see <https://github.com/rust-lang/rust/issues/134946>
+
+    // CHECK-NOT: movb
+    // CHECK-COUNT-8: movups{{.+}}xmm
+    // CHECK-NOT: movb
+    // CHECK-COUNT-4: movq
+    // CHECK-NOT: movb
+    // CHECK-COUNT-4: movl
+    // CHECK-NOT: movb
+    // CHECK: retq
+    swap(x, y)
+}
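
A rough sanity check of the counts above (my own arithmetic, not stated in the commit): on x86_64 the new implementation splits 44 bytes into five 8-byte chunks plus a 4-byte tail, and LLVM merges the chunk loop into 16-byte `movups` pairs with one `movq` left over, which is why the test expects 8 movups, 4 movq, and 4 movl.

// Hypothetical breakdown of the 44-byte swap checked above, assuming the
// 8-byte CHUNK_SIZE that swap_nonoverlapping_bytes uses on x86_64.
fn main() {
    let bytes = 44usize;
    let chunk_size = 8usize; // size_of::<*const ()>() on x86_64
    let chunks = bytes / chunk_size; // 5 chunks of 8 bytes = 40 bytes
    let tail = bytes % chunk_size; // 4 bytes, swapped as one 4-byte piece
    assert_eq!((chunks, tail), (5, 4));
    // Each buffer is read as 2 x 16-byte movups plus 1 x 8-byte movq (40 bytes)
    // and written back the same way; the 4-byte tail is a movl on each side.
    assert_eq!(2 * 16 + 8 + 4, bytes);
}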

tests/codegen/simd/swap-simd-types.rs
+4 -4

@@ -23,8 +23,8 @@ pub fn swap_single_m256(x: &mut __m256, y: &mut __m256) {
 #[no_mangle]
 pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
     // CHECK-NOT: alloca
-    // CHECK: load <8 x float>{{.+}}align 32
-    // CHECK: store <8 x float>{{.+}}align 32
+    // CHECK-COUNT-2: load <4 x i64>{{.+}}align 32
+    // CHECK-COUNT-2: store <4 x i64>{{.+}}align 32
     if x.len() == y.len() {
         x.swap_with_slice(y);
     }
@@ -34,7 +34,7 @@ pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
 #[no_mangle]
 pub fn swap_bytes32(x: &mut [u8; 32], y: &mut [u8; 32]) {
     // CHECK-NOT: alloca
-    // CHECK: load <32 x i8>{{.+}}align 1
-    // CHECK: store <32 x i8>{{.+}}align 1
+    // CHECK-COUNT-2: load <4 x i64>{{.+}}align 1
+    // CHECK-COUNT-2: store <4 x i64>{{.+}}align 1
     swap(x, y)
 }

tests/codegen/swap-large-types.rs
+37 -18

@@ -38,9 +38,6 @@ pub fn swap_std(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
     swap(x, y)
 }
 
-// Verify that types with usize alignment are swapped via vectored usizes,
-// not falling back to byte-level code.
-
 // CHECK-LABEL: @swap_slice
 #[no_mangle]
 pub fn swap_slice(x: &mut [KeccakBuffer], y: &mut [KeccakBuffer]) {
@@ -52,39 +49,61 @@ pub fn swap_slice(x: &mut [KeccakBuffer], y: &mut [KeccakBuffer]) {
     }
 }
 
-// But for a large align-1 type, vectorized byte copying is what we want.
-
 type OneKilobyteBuffer = [u8; 1024];
 
 // CHECK-LABEL: @swap_1kb_slices
 #[no_mangle]
 pub fn swap_1kb_slices(x: &mut [OneKilobyteBuffer], y: &mut [OneKilobyteBuffer]) {
     // CHECK-NOT: alloca
-    // CHECK: load <{{[0-9]+}} x i8>
-    // CHECK: store <{{[0-9]+}} x i8>
+
+    // CHECK-NOT: load i32
+    // CHECK-NOT: store i32
+    // CHECK-NOT: load i16
+    // CHECK-NOT: store i16
+    // CHECK-NOT: load i8
+    // CHECK-NOT: store i8
+
+    // CHECK: load <{{[0-9]+}} x i64>{{.+}}align 1,
+    // CHECK: store <{{[0-9]+}} x i64>{{.+}}align 1,
+
+    // CHECK-NOT: load i32
+    // CHECK-NOT: store i32
+    // CHECK-NOT: load i16
+    // CHECK-NOT: store i16
+    // CHECK-NOT: load i8
+    // CHECK-NOT: store i8
+
     if x.len() == y.len() {
         x.swap_with_slice(y);
     }
 }
 
-// This verifies that the 2×read + 2×write optimizes to just 3 memcpys
-// for an unusual type like this. It's not clear whether we should do anything
-// smarter in Rust for these, so for now it's fine to leave these up to the backend.
-// That's not as bad as it might seem, as for example, LLVM will lower the
-// memcpys below to VMOVAPS on YMMs if one enables the AVX target feature.
-// Eventually we'll be able to pass `align_of::<T>` to a const generic and
-// thus pick a smarter chunk size ourselves without huge code duplication.
-
 #[repr(align(64))]
 pub struct BigButHighlyAligned([u8; 64 * 3]);
 
 // CHECK-LABEL: @swap_big_aligned
 #[no_mangle]
 pub fn swap_big_aligned(x: &mut BigButHighlyAligned, y: &mut BigButHighlyAligned) {
     // CHECK-NOT: call void @llvm.memcpy
-    // CHECK: call void @llvm.memcpy.{{.+}}(ptr noundef nonnull align 64 dereferenceable(192)
-    // CHECK: call void @llvm.memcpy.{{.+}}(ptr noundef nonnull align 64 dereferenceable(192)
-    // CHECK: call void @llvm.memcpy.{{.+}}(ptr noundef nonnull align 64 dereferenceable(192)
+    // CHECK-NOT: load i32
+    // CHECK-NOT: store i32
+    // CHECK-NOT: load i16
+    // CHECK-NOT: store i16
+    // CHECK-NOT: load i8
+    // CHECK-NOT: store i8
+
+    // CHECK-COUNT-2: load <{{[0-9]+}} x i64>{{.+}}align 64,
+    // CHECK-COUNT-2: store <{{[0-9]+}} x i64>{{.+}}align 64,
+
+    // CHECK-COUNT-2: load <{{[0-9]+}} x i64>{{.+}}align 32,
+    // CHECK-COUNT-2: store <{{[0-9]+}} x i64>{{.+}}align 32,
+
+    // CHECK-NOT: load i32
+    // CHECK-NOT: store i32
+    // CHECK-NOT: load i16
+    // CHECK-NOT: store i16
+    // CHECK-NOT: load i8
+    // CHECK-NOT: store i8
     // CHECK-NOT: call void @llvm.memcpy
     swap(x, y)
 }
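
One more piece of arithmetic behind the `CHECK-NOT` lines (my own sketch, not part of the commit): `BigButHighlyAligned` is 64 * 3 = 192 bytes, which divides evenly into pointer-sized chunks, so the tail path never runs and no `i32`/`i16`/`i8` accesses should survive; vectorizing the remaining chunk loop into the `<{{[0-9]+}} x i64>` operations is left to LLVM.

// Sketch of the chunk/tail split for the 192-byte type above, assuming the
// 8-byte CHUNK_SIZE that swap_nonoverlapping_bytes uses on 64-bit targets.
fn main() {
    let bytes = 64 * 3; // size_of::<BigButHighlyAligned>()
    let chunk_size = 8;
    assert_eq!(bytes / chunk_size, 24); // whole chunks only, vectorized by LLVM
    assert_eq!(bytes % chunk_size, 0); // empty tail, hence the CHECK-NOTs above
}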

tests/codegen/swap-small-types.rs
+37 -14

@@ -27,13 +27,19 @@ pub fn swap_rgb48_manually(x: &mut RGB48, y: &mut RGB48) {
 pub fn swap_rgb48(x: &mut RGB48, y: &mut RGB48) {
     // CHECK-NOT: alloca
 
-    // Whether `i8` is the best for this is unclear, but
-    // might as well record what's actually happening right now.
-
-    // CHECK: load i8
-    // CHECK: load i8
-    // CHECK: store i8
-    // CHECK: store i8
+    // Swapping `i48` might be cleaner in LLVM-IR here, but `i32`+`i16` isn't bad,
+    // and is closer to the assembly it generates anyway.
+
+    // CHECK-NOT: load{{ }}
+    // CHECK: load i32{{.+}}align 2
+    // CHECK-NEXT: load i32{{.+}}align 2
+    // CHECK-NEXT: store i32{{.+}}align 2
+    // CHECK-NEXT: store i32{{.+}}align 2
+    // CHECK: load i16{{.+}}align 2
+    // CHECK-NEXT: load i16{{.+}}align 2
+    // CHECK-NEXT: store i16{{.+}}align 2
+    // CHECK-NEXT: store i16{{.+}}align 2
+    // CHECK-NOT: store{{ }}
     swap(x, y)
 }
 
@@ -76,30 +82,47 @@ pub fn swap_slices<'a>(x: &mut &'a [u32], y: &mut &'a [u32]) {
     swap(x, y)
 }
 
-// LLVM doesn't vectorize a loop over 3-byte elements,
-// so we chunk it down to bytes and loop over those instead.
 type RGB24 = [u8; 3];
 
 // CHECK-LABEL: @swap_rgb24_slices
 #[no_mangle]
 pub fn swap_rgb24_slices(x: &mut [RGB24], y: &mut [RGB24]) {
     // CHECK-NOT: alloca
-    // CHECK: load <{{[0-9]+}} x i8>
-    // CHECK: store <{{[0-9]+}} x i8>
+
+    // CHECK: load <{{[0-9]+}} x i64>
+    // CHECK: store <{{[0-9]+}} x i64>
+
+    // CHECK-COUNT-2: load i32
+    // CHECK-COUNT-2: store i32
+    // CHECK-COUNT-2: load i16
+    // CHECK-COUNT-2: store i16
+    // CHECK-COUNT-2: load i8
+    // CHECK-COUNT-2: store i8
     if x.len() == y.len() {
         x.swap_with_slice(y);
     }
 }
 
-// This one has a power-of-two size, so we iterate over it directly
 type RGBA32 = [u8; 4];
 
 // CHECK-LABEL: @swap_rgba32_slices
 #[no_mangle]
 pub fn swap_rgba32_slices(x: &mut [RGBA32], y: &mut [RGBA32]) {
     // CHECK-NOT: alloca
-    // CHECK: load <{{[0-9]+}} x i32>
-    // CHECK: store <{{[0-9]+}} x i32>
+
+    // Because the size in bytes in a multiple of 4, we can skip the smallest sizes.
+
+    // CHECK: load <{{[0-9]+}} x i64>
+    // CHECK: store <{{[0-9]+}} x i64>
+
+    // CHECK-COUNT-2: load i32
+    // CHECK-COUNT-2: store i32
+
+    // CHECK-NOT: load i16
+    // CHECK-NOT: store i16
+    // CHECK-NOT: load i8
+    // CHECK-NOT: store i8
+
     if x.len() == y.len() {
         x.swap_with_slice(y);
     }
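
A worked example for the checks in this file (a sketch based on the new 4/2/1 tail handling, not part of the commit): `RGB48` is 6 bytes, so on a 64-bit target there is no full pointer-sized chunk and the tail path swaps an `i32` followed by an `i16`. For the `RGBA32` slices, the total byte count is a multiple of 4, so the tail is always 0 or 4 bytes and the `i16`/`i8` cases can never fire, while 3-byte `RGB24` elements can leave any remainder.

fn main() {
    let chunk_size = 8usize; // pointer size assumed for the 64-bit targets of these tests
    let bytes = 6usize; // size_of::<RGB48>(), i.e. [u16; 3]
    let chunks = bytes / chunk_size;
    let tail = bytes % chunk_size;
    // The tail is swapped by testing bits 4, 2, 1 in turn, as in swap_prefix!(4 2 1).
    let pieces: Vec<usize> = [4, 2, 1].into_iter().filter(|&n| (tail & n) != 0).collect();
    assert_eq!(chunks, 0); // no whole 8-byte chunk
    assert_eq!(pieces, [4, 2]); // one i32 swap, then one i16 swap, as checked above
}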

tests/ui/consts/missing_span_in_backtrace.stderr
+1 -1

@@ -5,7 +5,7 @@ error[E0080]: evaluation of constant value failed
 |
 note: inside `std::ptr::read::<MaybeUninit<MaybeUninit<u8>>>`
 --> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
-note: inside `std::ptr::swap_nonoverlapping_simple_untyped::<MaybeUninit<u8>>`
+note: inside `std::ptr::swap_nonoverlapping_const::<MaybeUninit<u8>>`
 --> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
 note: inside `swap_nonoverlapping::compiletime::<MaybeUninit<u8>>`
 --> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
0 commit comments