@@ -398,6 +398,7 @@ use crate::cmp::Ordering;
398
398
use crate :: intrinsics:: const_eval_select;
399
399
use crate :: marker:: FnPtr ;
400
400
use crate :: mem:: { self , MaybeUninit , SizedTypeProperties } ;
401
+ use crate :: num:: NonZero ;
401
402
use crate :: { fmt, hash, intrinsics, ub_checks} ;
402
403
403
404
mod alignment;
@@ -1094,51 +1095,25 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
1094
1095
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
1095
1096
// of a pointer (which would not work).
1096
1097
// SAFETY: Same preconditions as this function
1097
- unsafe { swap_nonoverlapping_simple_untyped ( x, y, count) }
1098
+ unsafe { swap_nonoverlapping_const ( x, y, count) }
1098
1099
} else {
1099
- macro_rules! attempt_swap_as_chunks {
1100
- ( $ChunkTy : ty) => {
1101
- if align_of:: <T >( ) >= align_of:: <$ChunkTy >( )
1102
- && size_of:: <T >( ) % size_of:: <$ChunkTy >( ) == 0
1103
- {
1104
- let x: * mut $ChunkTy = x. cast( ) ;
1105
- let y: * mut $ChunkTy = y. cast( ) ;
1106
- let count = count * ( size_of:: <T >( ) / size_of:: <$ChunkTy >( ) ) ;
1107
- // SAFETY: these are the same bytes that the caller promised were
1108
- // ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
1109
- // The `if` condition above ensures that we're not violating
1110
- // alignment requirements, and that the division is exact so
1111
- // that we don't lose any bytes off the end.
1112
- return unsafe { swap_nonoverlapping_simple_untyped( x, y, count) } ;
1113
- }
1114
- } ;
1100
+ // Going through a slice here helps codegen know the size fits in `isize`
1101
+ let slice = slice_from_raw_parts_mut( x, count) ;
1102
+ // SAFETY: This is all readable from the pointer, meaning it's one
1103
+ // allocated object, and thus cannot be more than isize::MAX bytes.
1104
+ let bytes = unsafe { mem:: size_of_val_raw:: <[ T ] >( slice) } ;
1105
+ if let Some ( bytes) = NonZero :: new( bytes) {
1106
+ // SAFETY: These are the same ranges, just expressed in a different
1107
+ // type, so they're still non-overlapping.
1108
+ unsafe { swap_nonoverlapping_bytes( x. cast( ) , y. cast( ) , bytes) } ;
1115
1109
}
1116
-
1117
- // Split up the slice into small power-of-two-sized chunks that LLVM is able
1118
- // to vectorize (unless it's a special type with more-than-pointer alignment,
1119
- // because we don't want to pessimize things like slices of SIMD vectors.)
1120
- if align_of:: <T >( ) <= size_of:: <usize >( )
1121
- && ( !size_of:: <T >( ) . is_power_of_two( )
1122
- || size_of:: <T >( ) > size_of:: <usize >( ) * 2 )
1123
- {
1124
- attempt_swap_as_chunks!( usize ) ;
1125
- attempt_swap_as_chunks!( u8 ) ;
1126
- }
1127
-
1128
- // SAFETY: Same preconditions as this function
1129
- unsafe { swap_nonoverlapping_simple_untyped( x, y, count) }
1130
1110
}
1131
1111
)
1132
1112
}
1133
1113
1134
1114
/// Same behavior and safety conditions as [`swap_nonoverlapping`]
1135
- ///
1136
- /// LLVM can vectorize this (at least it can for the power-of-two-sized types
1137
- /// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
1138
1115
#[ inline]
1139
- const unsafe fn swap_nonoverlapping_simple_untyped < T > ( x : * mut T , y : * mut T , count : usize ) {
1140
- let x = x. cast :: < MaybeUninit < T > > ( ) ;
1141
- let y = y. cast :: < MaybeUninit < T > > ( ) ;
1116
+ const unsafe fn swap_nonoverlapping_const < T > ( x : * mut T , y : * mut T , count : usize ) {
1142
1117
let mut i = 0 ;
1143
1118
while i < count {
1144
1119
// SAFETY: By precondition, `i` is in-bounds because it's below `count`
@@ -1147,26 +1122,91 @@ const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, coun
1147
1122
// and it's distinct from `x` since the ranges are non-overlapping
1148
1123
let y = unsafe { y. add ( i) } ;
1149
1124
1150
- // If we end up here, it's because we're using a simple type -- like
1151
- // a small power-of-two-sized thing -- or a special type with particularly
1152
- // large alignment, particularly SIMD types.
1153
- // Thus, we're fine just reading-and-writing it, as either it's small
1154
- // and that works well anyway or it's special and the type's author
1155
- // presumably wanted things to be done in the larger chunk.
1156
-
1157
1125
// SAFETY: we're only ever given pointers that are valid to read/write,
1158
1126
// including being aligned, and nothing here panics so it's drop-safe.
1159
1127
unsafe {
1160
- let a: MaybeUninit < T > = read ( x) ;
1161
- let b: MaybeUninit < T > = read ( y) ;
1162
- write ( x, b) ;
1163
- write ( y, a) ;
1128
+ // Note that it's critical that these use `copy_nonoverlapping`,
1129
+ // rather than `read`/`write`, to avoid #134713 if T has padding.
1130
+ let mut temp = MaybeUninit :: < T > :: uninit ( ) ;
1131
+ copy_nonoverlapping ( x, temp. as_mut_ptr ( ) , 1 ) ;
1132
+ copy_nonoverlapping ( y, x, 1 ) ;
1133
+ copy_nonoverlapping ( temp. as_ptr ( ) , y, 1 ) ;
1164
1134
}
1165
1135
1166
1136
i += 1 ;
1167
1137
}
1168
1138
}
1169
1139
1140
+ // Don't let MIR inline this, because we really want it to keep its noalias metadata
// (presumably the metadata that comes from taking `x` and `y` as `&mut` references).
//
// Swaps the `N`-byte contents of `x` and `y`. Both values are copied out before
// either one is written back, and the bytes are only ever handled as `MaybeUninit`,
// so nothing here requires them to be initialized.
1141
+ #[ rustc_no_mir_inline]
1142
+ #[ inline]
1143
+ fn swap_chunk < const N : usize > ( x : & mut MaybeUninit < [ u8 ; N ] > , y : & mut MaybeUninit < [ u8 ; N ] > ) {
1144
+ let a = * x; // copy of the bytes currently behind `x`
1145
+ let b = * y; // copy of the bytes currently behind `y`
1146
+ * x = b; // both loads are done; store the values crossed over
1147
+ * y = a;
1148
+ }
1149
+
1150
// Swaps `bytes` bytes between the ranges starting at `x` and `y`.
//
// The work is split into `bytes / CHUNK_SIZE` pointer-sized chunks, followed by a
// tail of `bytes % CHUNK_SIZE` bytes that is handed to `swap_nonoverlapping_short`.
//
// NOTE(review): no safety contract is spelled out on this function. The SAFETY
// comments in and around it suggest that both ranges must be valid for reads and
// writes of `bytes` bytes and must not overlap -- confirm against the callers.
+ #[ inline]
1151
+ unsafe fn swap_nonoverlapping_bytes ( x : * mut u8 , y : * mut u8 , bytes : NonZero < usize > ) {
1152
+ // Same as `swap_nonoverlapping::<[u8; N]>`: swaps `chunks` consecutive `N`-byte chunks,
// one `swap_chunk` call per chunk.
1153
+ unsafe fn swap_nonoverlapping_chunks < const N : usize > (
1154
+ x : * mut MaybeUninit < [ u8 ; N ] > ,
1155
+ y : * mut MaybeUninit < [ u8 ; N ] > ,
1156
+ chunks : NonZero < usize > ,
1157
+ ) {
1158
+ let chunks = chunks. get ( ) ;
1159
+ for i in 0 ..chunks {
1160
+ // SAFETY: i is in [0, chunks) so the adds and dereferences are in-bounds.
1161
+ unsafe { swap_chunk ( & mut * x. add ( i) , & mut * y. add ( i) ) } ;
1162
+ }
1163
+ }
1164
+
1165
+ // Same as `swap_nonoverlapping_bytes`, but accepts at most 1+2+4=7 bytes
1166
+ #[ inline]
1167
+ unsafe fn swap_nonoverlapping_short ( x : * mut u8 , y : * mut u8 , bytes : NonZero < usize > ) {
1168
+ // Tail handling for auto-vectorized code sometimes has element-at-a-time behaviour,
1169
+ // see <https://github.com/rust-lang/rust/issues/134946>.
1170
+ // By swapping as different sizes, rather than as a loop over bytes,
1171
+ // we make sure not to end up with, say, seven byte-at-a-time copies.
1172
+
1173
+ let bytes = bytes. get ( ) ;
1174
+ let mut i = 0 ; // byte offset of the next chunk to swap; also the set of sizes already handled
1175
+ macro_rules! swap_prefix {
1176
+ ( $( $n: literal) +) => { $(
// For each listed size `$n`: if that bit is set in `bytes`, swap one `$n`-byte
// chunk at offset `i`, then record the size in `i`.
1177
+ if ( bytes & $n) != 0 {
1178
+ // SAFETY: `i` can only have the same bits set as those in bytes,
1179
+ // so these `add`s are in-bounds of `bytes`. But the bit for
1180
+ // `$n` hasn't been set yet, so the `$n` bytes that `swap_chunk`
1181
+ // will read and write are within the usable range.
1182
+ unsafe { swap_chunk:: <$n>( & mut * x. add( i) . cast( ) , & mut * y. add( i) . cast( ) ) } ;
1183
+ i |= $n;
1184
+ }
1185
+ ) +} ;
1186
+ }
1187
+ swap_prefix ! ( 4 2 1 ) ; // sizes are tried largest first
1188
+ debug_assert_eq ! ( i, bytes) ; // fails if `bytes` has any bit set other than 4, 2 or 1, i.e. if it exceeds 7
1189
+ }
1190
+
1191
+ const CHUNK_SIZE : usize = size_of :: < * const ( ) > ( ) ; // one chunk is the size of a thin pointer
1192
+ let bytes = bytes. get ( ) ;
1193
+
1194
+ let chunks = bytes / CHUNK_SIZE ;
1195
+ let tail = bytes % CHUNK_SIZE ;
1196
+ if let Some ( chunks) = NonZero :: new ( chunks) {
1197
+ // SAFETY: this is bytes/CHUNK_SIZE*CHUNK_SIZE bytes, which is <= bytes,
1198
+ // so it's within the range of our non-overlapping bytes.
1199
+ unsafe { swap_nonoverlapping_chunks :: < CHUNK_SIZE > ( x. cast ( ) , y. cast ( ) , chunks) } ;
1200
+ }
1201
+ if let Some ( tail) = NonZero :: new ( tail) {
// Compile-time guard: the tail helper only handles up to 7 bytes, so a chunk
// must not be larger than 8 bytes.
1202
+ const { assert ! ( CHUNK_SIZE <= 8 ) } ;
1203
+ let delta = chunks * CHUNK_SIZE ; // number of bytes covered by the whole chunks above
1204
+ // SAFETY: the tail length is below CHUNK_SIZE because of the remainder,
1205
+ // and CHUNK_SIZE is at most 8 by the const assert, so tail <= 7
1206
+ unsafe { swap_nonoverlapping_short ( x. add ( delta) , y. add ( delta) , tail) } ;
1207
+ }
1208
+ }
1209
+
1170
1210
/// Moves `src` into the pointed `dst`, returning the previous `dst` value.
1171
1211
///
1172
1212
/// Neither value is dropped.
0 commit comments