Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure swap_nonoverlapping is really always untyped #137412

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 88 additions & 48 deletions library/core/src/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ use crate::cmp::Ordering;
use crate::intrinsics::const_eval_select;
use crate::marker::FnPtr;
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
use crate::num::NonZero;
use crate::{fmt, hash, intrinsics, ub_checks};

mod alignment;
Expand Down Expand Up @@ -1094,51 +1095,25 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
// of a pointer (which would not work).
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
unsafe { swap_nonoverlapping_const(x, y, count) }
} else {
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if align_of::<T>() >= align_of::<$ChunkTy>()
&& size_of::<T>() % size_of::<$ChunkTy>() == 0
{
let x: *mut $ChunkTy = x.cast();
let y: *mut $ChunkTy = y.cast();
let count = count * (size_of::<T>() / size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
}
};
// Going through a slice here helps codegen know the size fits in `isize`
let slice = slice_from_raw_parts_mut(x, count);
// SAFETY: This is all readable from the pointer, meaning it's one
// allocated object, and thus cannot be more than isize::MAX bytes.
let bytes = unsafe { mem::size_of_val_raw::<[T]>(slice) };
if let Some(bytes) = NonZero::new(bytes) {
// SAFETY: These are the same ranges, just expressed in a different
// type, so they're still non-overlapping.
unsafe { swap_nonoverlapping_bytes(x.cast(), y.cast(), bytes) };
}

// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if align_of::<T>() <= size_of::<usize>()
&& (!size_of::<T>().is_power_of_two()
|| size_of::<T>() > size_of::<usize>() * 2)
{
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}

// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
}
)
}

/// Same behavior and safety conditions as [`swap_nonoverlapping`]
///
/// LLVM can vectorize this (at least it can for the power-of-two-sized types
/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
#[inline]
const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, count: usize) {
let x = x.cast::<MaybeUninit<T>>();
let y = y.cast::<MaybeUninit<T>>();
const unsafe fn swap_nonoverlapping_const<T>(x: *mut T, y: *mut T, count: usize) {
let mut i = 0;
while i < count {
// SAFETY: By precondition, `i` is in-bounds because it's below `count`
Expand All @@ -1147,26 +1122,91 @@ const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, coun
// and it's distinct from `x` since the ranges are non-overlapping
let y = unsafe { y.add(i) };

// If we end up here, it's because we're using a simple type -- like
// a small power-of-two-sized thing -- or a special type with particularly
// large alignment, particularly SIMD types.
// Thus, we're fine just reading-and-writing it, as either it's small
// and that works well anyway or it's special and the type's author
// presumably wanted things to be done in the larger chunk.

// SAFETY: we're only ever given pointers that are valid to read/write,
// including being aligned, and nothing here panics so it's drop-safe.
unsafe {
let a: MaybeUninit<T> = read(x);
let b: MaybeUninit<T> = read(y);
write(x, b);
write(y, a);
// Note that it's critical that these use `copy_nonoverlapping`,
// rather than `read`/`write`, to avoid #134713 if T has padding.
let mut temp = MaybeUninit::<T>::uninit();
copy_nonoverlapping(x, temp.as_mut_ptr(), 1);
copy_nonoverlapping(y, x, 1);
copy_nonoverlapping(temp.as_ptr(), y, 1);
}

i += 1;
}
}

// Don't let MIR inline this, because we really want it to keep its noalias metadata
/// Swaps the `N` bytes behind `x` with the `N` bytes behind `y`.
///
/// Both sides are typed as `MaybeUninit<[u8; N]>`, so the contents are copied
/// as plain bytes with no validity requirement on what they hold.
#[rustc_no_mir_inline]
#[inline]
fn swap_chunk<const N: usize>(x: &mut MaybeUninit<[u8; N]>, y: &mut MaybeUninit<[u8; N]>) {
// Load both chunks first, then store each one into the opposite location.
let a = *x;
let b = *y;
*x = b;
*y = a;
}

/// Swaps the `bytes` bytes starting at `x` with the `bytes` bytes starting at `y`.
///
/// The range is processed as `bytes / CHUNK_SIZE` pointer-sized chunks, and the
/// remaining `bytes % CHUNK_SIZE` bytes (if any) are handed to
/// `swap_nonoverlapping_short`.
///
/// # Safety
///
/// NOTE(review): the contract is not written down here. Presumably it is the
/// same as `swap_nonoverlapping::<u8>(x, y, bytes.get())`: both ranges valid
/// for reads and writes of `bytes` bytes and not overlapping each other.
/// Confirm against the caller.
#[inline]
unsafe fn swap_nonoverlapping_bytes(x: *mut u8, y: *mut u8, bytes: NonZero<usize>) {
// Same as `swap_nonoverlapping::<[u8; N]>`.
unsafe fn swap_nonoverlapping_chunks<const N: usize>(
x: *mut MaybeUninit<[u8; N]>,
y: *mut MaybeUninit<[u8; N]>,
chunks: NonZero<usize>,
) {
let chunks = chunks.get();
for i in 0..chunks {
// SAFETY: i is in [0, chunks) so the adds and dereferences are in-bounds.
unsafe { swap_chunk(&mut *x.add(i), &mut *y.add(i)) };
}
}

// Same as `swap_nonoverlapping_bytes`, but accepts at most 1+2+4=7 bytes
#[inline]
unsafe fn swap_nonoverlapping_short(x: *mut u8, y: *mut u8, bytes: NonZero<usize>) {
// Tail handling for auto-vectorized code sometimes has element-at-a-time behaviour,
// see <https://github.com/rust-lang/rust/issues/134946>.
// By swapping as different sizes, rather than as a loop over bytes,
// we make sure not to end up with, say, seven byte-at-a-time copies.

let bytes = bytes.get();
// Offset (in bytes) of the part of the range that has already been swapped.
let mut i = 0;
// For each listed size whose bit is set in `bytes`, swap one chunk of that
// size at offset `i` and then record the size in `i`.
macro_rules! swap_prefix {
($($n:literal)+) => {$(
if (bytes & $n) != 0 {
// SAFETY: `i` can only have the same bits set as those in bytes,
// so these `add`s are in-bounds of `bytes`. But the bit for
// `$n` hasn't been set yet, so the `$n` bytes that `swap_chunk`
// will read and write are within the usable range.
unsafe { swap_chunk::<$n>(&mut*x.add(i).cast(), &mut*y.add(i).cast()) };
i |= $n;
}
)+};
}
swap_prefix!(4 2 1);
// Only bits 4, 2 and 1 are handled above, so this holds exactly when `bytes <= 7`.
debug_assert_eq!(i, bytes);
}

// Full chunks are pointer-sized.
const CHUNK_SIZE: usize = size_of::<*const ()>();
let bytes = bytes.get();

let chunks = bytes / CHUNK_SIZE;
let tail = bytes % CHUNK_SIZE;
if let Some(chunks) = NonZero::new(chunks) {
// SAFETY: this is bytes/CHUNK_SIZE*CHUNK_SIZE bytes, which is <= bytes,
// so it's within the range of our non-overlapping bytes.
unsafe { swap_nonoverlapping_chunks::<CHUNK_SIZE>(x.cast(), y.cast(), chunks) };
}
if let Some(tail) = NonZero::new(tail) {
const { assert!(CHUNK_SIZE <= 8) };
// Byte offset just past the full chunks handled above; `chunks` here is the
// plain `usize` count, not the `NonZero` binding scoped to the previous `if let`.
let delta = chunks * CHUNK_SIZE;
// SAFETY: the tail length is below `CHUNK_SIZE` because of the remainder,
// and `CHUNK_SIZE` is at most 8 by the const assert, so tail <= 7
unsafe { swap_nonoverlapping_short(x.add(delta), y.add(delta), tail) };
}
}

/// Moves `src` into the pointed `dst`, returning the previous `dst` value.
///
/// Neither value is dropped.
Expand Down
36 changes: 36 additions & 0 deletions library/coretests/tests/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,3 +984,39 @@ fn test_ptr_metadata_in_const() {
assert_eq!(SLICE_META, 3);
assert_eq!(DYN_META.size_of(), 42);
}

// See <https://github.com/rust-lang/rust/issues/134713>
/// Checks that `ptr::swap_nonoverlapping` exchanges every byte of the two
/// ranges, including bytes that are padding in the pointee type.
///
/// This is a `const fn` so the same body can be run both at runtime and
/// during const evaluation.
const fn ptr_swap_nonoverlapping_is_untyped_inner() {
    #[repr(C)]
    struct HasPadding(usize, u8);

    // The swap below views `[usize; 2]` storage as `HasPadding`, which is only
    // meaningful if both types agree on size and alignment.
    assert!(size_of::<HasPadding>() == size_of::<[usize; 2]>());
    assert!(align_of::<HasPadding>() == align_of::<[usize; 2]>());

    // Reference values, plus the mutable copies that actually get swapped.
    let buf1: [usize; 2] = [1000, 2000];
    let buf2: [usize; 2] = [3000, 4000];
    let mut b1 = buf1;
    let mut b2 = buf2;

    let lhs = b1.as_mut_ptr().cast::<HasPadding>();
    let rhs = b2.as_mut_ptr().cast::<HasPadding>();
    // SAFETY: `lhs` and `rhs` point into two distinct local variables, each of
    // which has the same size and alignment as one `HasPadding`.
    unsafe { std::ptr::swap_nonoverlapping(lhs, rhs, 1) };

    // Both words must have moved, including the second one, which covers
    // `HasPadding`'s `u8` field and the padding after it.
    assert!(b1[0] == buf2[0]);
    assert!(b1[1] == buf2[1]);
    assert!(b2[0] == buf1[0]);
    assert!(b2[1] == buf1[1]);
}

/// Runs the untyped-swap check in both evaluation modes.
#[test]
fn test_ptr_swap_nonoverlapping_is_untyped() {
    // Const evaluation: a failing assertion here is a compile-time error.
    const { ptr_swap_nonoverlapping_is_untyped_inner() };
    // Ordinary runtime execution of the same body.
    ptr_swap_nonoverlapping_is_untyped_inner();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use std::mem::{size_of, align_of};

// See <https://github.com/rust-lang/rust/issues/134713>

// A `usize` followed by a `u8`, laid out in declaration order by `repr(C)`.
// `main` asserts that this has the same size and alignment as `[usize; 2]`,
// so the bytes after the `u8` field are padding.
#[repr(C)]
struct Foo(usize, u8);

// Swaps two `[usize; 2]` buffers through `*mut Foo` pointers and checks that
// the complete contents of both buffers were exchanged.
fn main() {
    // Viewing the buffers as `Foo` below relies on `Foo` and `[usize; 2]`
    // agreeing on size and alignment, so check that first.
    assert_eq!(size_of::<Foo>(), size_of::<[usize; 2]>());
    assert_eq!(align_of::<Foo>(), align_of::<[usize; 2]>());

    // Reference values, plus the mutable copies that actually get swapped.
    let original_a: [usize; 2] = [1000, 2000];
    let original_b: [usize; 2] = [3000, 4000];
    let mut work_a = original_a;
    let mut work_b = original_b;

    let foo_a = work_a.as_mut_ptr().cast::<Foo>();
    let foo_b = work_b.as_mut_ptr().cast::<Foo>();
    // SAFETY: `foo_a` and `foo_b` point into two distinct local variables,
    // each with the same size and alignment as one `Foo`.
    unsafe { std::ptr::swap_nonoverlapping(foo_a, foo_b, 1) };

    // Each buffer must now hold exactly what the other one started with.
    assert_eq!(work_a, original_b);
    assert_eq!(work_b, original_a);
}
28 changes: 28 additions & 0 deletions tests/assembly/x86_64-typed-swap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,31 @@ pub fn swap_simd(x: &mut __m128, y: &mut __m128) {
// CHECK-NEXT: retq
swap(x, y)
}

// CHECK-LABEL: swap_string:
#[no_mangle]
pub fn swap_string(x: &mut String, y: &mut String) {
// The directives below expect the swap of two `String`s to show up as a
// group of four `movups` followed by a group of four `movq`, with nothing
// else matching `mov` before the first group, between the groups, or after
// the second group.
// CHECK-NOT: mov
// CHECK-COUNT-4: movups
// CHECK-NOT: mov
// CHECK-COUNT-4: movq
// CHECK-NOT: mov
swap(x, y)
}

// CHECK-LABEL: swap_44_bytes:
#[no_mangle]
pub fn swap_44_bytes(x: &mut [u8; 44], y: &mut [u8; 44]) {
// Ensure we do better than a long run of byte copies,
// see <https://github.com/rust-lang/rust/issues/134946>

// CHECK-NOT: movb
// CHECK-COUNT-8: movups{{.+}}xmm
// CHECK-NOT: movb
// CHECK-COUNT-4: movq
// CHECK-NOT: movb
// CHECK-COUNT-4: movl
// CHECK-NOT: movb
// CHECK: retq
Comment on lines +69 to +79
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for reviewers: the codegen tests here are more about demonstrating what actually happens on a variety of types, and the exact details don't matter that much.

Reviewing the rust code is enough to know that LLVM will swap it, but for example here what we're trying to see is that it's not just a huge row of movbs like you can see in https://rust.godbolt.org/z/MKfxn1Tjr

swap(x, y)
}
8 changes: 4 additions & 4 deletions tests/codegen/simd/swap-simd-types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ pub fn swap_single_m256(x: &mut __m256, y: &mut __m256) {
#[no_mangle]
pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
// CHECK-NOT: alloca
// CHECK: load <8 x float>{{.+}}align 32
// CHECK: store <8 x float>{{.+}}align 32
// CHECK-COUNT-2: load <4 x i64>{{.+}}align 32
// CHECK-COUNT-2: store <4 x i64>{{.+}}align 32
if x.len() == y.len() {
x.swap_with_slice(y);
}
Expand All @@ -34,7 +34,7 @@ pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
#[no_mangle]
pub fn swap_bytes32(x: &mut [u8; 32], y: &mut [u8; 32]) {
// CHECK-NOT: alloca
// CHECK: load <32 x i8>{{.+}}align 1
// CHECK: store <32 x i8>{{.+}}align 1
// CHECK-COUNT-2: load <4 x i64>{{.+}}align 1
// CHECK-COUNT-2: store <4 x i64>{{.+}}align 1
swap(x, y)
}
Loading
Loading