Skip to content

Commit 28f9fe3

Browse files
committed
Auto merge of #3204 - RalfJung:simd, r=RalfJung
add new SIMD intrinsics
2 parents 2a316c4 + 2903f1c commit 28f9fe3

File tree

2 files changed

+74
-5
lines changed

2 files changed

+74
-5
lines changed

src/tools/miri/src/shims/intrinsics/simd.rs

+40-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use rustc_apfloat::{Float, Round};
22
use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
33
use rustc_middle::{mir, ty, ty::FloatTy};
4+
use rustc_span::{sym, Symbol};
45
use rustc_target::abi::{Endian, HasDataLayout};
56

67
use crate::*;
@@ -25,7 +26,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
2526
| "floor"
2627
| "round"
2728
| "trunc"
28-
| "fsqrt" => {
29+
| "fsqrt"
30+
| "ctlz"
31+
| "cttz"
32+
| "bswap"
33+
| "bitreverse"
34+
=> {
2935
let [op] = check_arg_count(args)?;
3036
let (op, op_len) = this.operand_to_simd(op)?;
3137
let (dest, dest_len) = this.place_to_simd(dest)?;
@@ -38,6 +44,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
3844
Abs,
3945
Sqrt,
4046
Round(rustc_apfloat::Round),
47+
Numeric(Symbol),
4148
}
4249
let which = match intrinsic_name {
4350
"neg" => Op::MirOp(mir::UnOp::Neg),
@@ -47,6 +54,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
4754
"floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
4855
"round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
4956
"trunc" => Op::Round(rustc_apfloat::Round::TowardZero),
57+
"ctlz" => Op::Numeric(sym::ctlz),
58+
"cttz" => Op::Numeric(sym::cttz),
59+
"bswap" => Op::Numeric(sym::bswap),
60+
"bitreverse" => Op::Numeric(sym::bitreverse),
5061
_ => unreachable!(),
5162
};
5263

@@ -101,6 +112,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
101112
}
102113
}
103114
}
115+
Op::Numeric(name) => {
116+
assert!(op.layout.ty.is_integral());
117+
let size = op.layout.size;
118+
let bits = op.to_scalar().to_bits(size).unwrap();
119+
let extra = 128u128.checked_sub(u128::from(size.bits())).unwrap();
120+
let bits_out = match name {
121+
sym::ctlz => u128::from(bits.leading_zeros()).checked_sub(extra).unwrap(),
122+
sym::cttz => u128::from((bits << extra).trailing_zeros()).checked_sub(extra).unwrap(),
123+
sym::bswap => (bits << extra).swap_bytes(),
124+
sym::bitreverse => (bits << extra).reverse_bits(),
125+
_ => unreachable!(),
126+
};
127+
Scalar::from_uint(bits_out, size)
128+
}
104129
};
105130
this.write_scalar(val, &dest)?;
106131
}
@@ -126,7 +151,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
126151
| "fmin"
127152
| "saturating_add"
128153
| "saturating_sub"
129-
| "arith_offset" => {
154+
| "arith_offset"
155+
=> {
130156
use mir::BinOp;
131157

132158
let [left, right] = check_arg_count(args)?;
@@ -386,16 +412,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
386412
let (dest, dest_len) = this.place_to_simd(dest)?;
387413
let bitmask_len = dest_len.max(8);
388414

389-
assert!(mask.layout.ty.is_integral());
390415
assert!(bitmask_len <= 64);
391416
assert_eq!(bitmask_len, mask.layout.size.bits());
392417
assert_eq!(dest_len, yes_len);
393418
assert_eq!(dest_len, no_len);
394419
let dest_len = u32::try_from(dest_len).unwrap();
395420
let bitmask_len = u32::try_from(bitmask_len).unwrap();
396421

397-
let mask: u64 =
398-
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap();
422+
// The mask can be a single integer or an array.
423+
let mask: u64 = match mask.layout.ty.kind() {
424+
ty::Int(..) | ty::Uint(..) =>
425+
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(),
426+
ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => {
427+
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
428+
let mask = mask.transmute(mask_ty, this)?;
429+
this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap()
430+
}
431+
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
432+
};
433+
399434
for i in 0..dest_len {
400435
let mask = mask
401436
& 1u64

src/tools/miri/tests/pass/portable-simd.rs

+34
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,24 @@ fn simd_ops_i32() {
197197
assert_eq!(b.reduce_or(), -1);
198198
assert_eq!(a.reduce_xor(), 0);
199199
assert_eq!(b.reduce_xor(), -4);
200+
201+
assert_eq!(b.leading_zeros(), u32x4::from_array([31, 30, 30, 0]));
202+
assert_eq!(b.trailing_zeros(), u32x4::from_array([0, 1, 0, 2]));
203+
assert_eq!(b.leading_ones(), u32x4::from_array([0, 0, 0, 30]));
204+
assert_eq!(b.trailing_ones(), u32x4::from_array([1, 0, 2, 0]));
205+
assert_eq!(
206+
b.swap_bytes(),
207+
i32x4::from_array([0x01000000, 0x02000000, 0x03000000, 0xfcffffffu32 as i32])
208+
);
209+
assert_eq!(
210+
b.reverse_bits(),
211+
i32x4::from_array([
212+
0x80000000u32 as i32,
213+
0x40000000,
214+
0xc0000000u32 as i32,
215+
0x3fffffffu32 as i32
216+
])
217+
);
200218
}
201219

202220
fn simd_mask() {
@@ -247,6 +265,22 @@ fn simd_mask() {
247265
assert_eq!(bitmask2, [0b0001]);
248266
}
249267
}
268+
269+
// This used to cause an ICE.
270+
let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
271+
assert_eq!(
272+
mask32x8::from_bitmask_vector(bitmask),
273+
mask32x8::from_array([true, false, true, false, false, false, true, false]),
274+
);
275+
let bitmask =
276+
u8x16::from_array([0b01000101, 0b11110000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
277+
assert_eq!(
278+
mask32x16::from_bitmask_vector(bitmask),
279+
mask32x16::from_array([
280+
true, false, true, false, false, false, true, false, false, false, false, false, true,
281+
true, true, true,
282+
]),
283+
);
250284
}
251285

252286
fn simd_cast() {

0 commit comments

Comments
 (0)