Commit addae07

Auto merge of rust-lang#138391 - scottmcm:SSA-discriminants, r=WaffleLapkin
Don't `alloca` just to look at a discriminant

Today we're making LLVM do a bunch of extra work when you match on trivial stuff like `Option<bool>` or `ControlFlow<u8>`. This PR changes that so that simple types like `Option<u32>` or `Result<(), Box<Error>>` can stay as `OperandValue::ScalarPair` and we can still read the discriminant from them, rather than needing to write them into memory to have a `PlaceValue` just to get the discriminant out.

Fixes rust-lang#137503
2 parents cbfdf0b + 2b15dd1 commit addae07
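
For illustration, here is a minimal sketch (a hypothetical example, not code from the commit) of the kind of source this affects: each match below only inspects the discriminant of a small two-scalar value, so after this change the operand can stay as an SSA `OperandValue::ScalarPair` instead of being spilled to a stack slot first.

    // Hypothetical illustration, not from the PR.
    use std::ops::ControlFlow;

    // Only the discriminant of `x` is needed; `Option<u32>` is a ScalarPair.
    pub fn is_some(x: Option<u32>) -> bool {
        match x {
            Some(_) => true,
            None => false,
        }
    }

    // Likewise for `ControlFlow<u8>`: the match only looks at the tag.
    pub fn is_break(x: ControlFlow<u8>) -> bool {
        match x {
            ControlFlow::Break(_) => true,
            ControlFlow::Continue(()) => false,
        }
    }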

9 files changed, +186 -167 lines changed

compiler/rustc_codegen_ssa/src/mir/analyze.rs

+7-3
@@ -205,7 +205,12 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
             | PlaceContext::MutatingUse(MutatingUseContext::Retag) => {}
 
             PlaceContext::NonMutatingUse(
-                NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
+                NonMutatingUseContext::Copy
+                | NonMutatingUseContext::Move
+                // Inspect covers things like `PtrMetadata` and `Discriminant`
+                // which we can treat similar to `Copy` use for the purpose of
+                // whether we can use SSA variables for things.
+                | NonMutatingUseContext::Inspect,
             ) => match &mut self.locals[local] {
                 LocalKind::ZST => {}
                 LocalKind::Memory => {}
@@ -229,8 +234,7 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
                 | MutatingUseContext::Projection,
             )
             | PlaceContext::NonMutatingUse(
-                NonMutatingUseContext::Inspect
-                | NonMutatingUseContext::SharedBorrow
+                NonMutatingUseContext::SharedBorrow
                 | NonMutatingUseContext::FakeBorrow
                 | NonMutatingUseContext::RawBorrow
                 | NonMutatingUseContext::Projection,
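
As a rough illustration of what this enables (a hypothetical example, not from the diff): a local whose only use is a MIR `Discriminant` read, which reaches the analyzer as a `NonMutatingUseContext::Inspect`, can now be classified like a `Copy` use and kept as an SSA operand rather than forced into memory.

    // Hypothetical illustration. `matches!` lowers to roughly
    // `_2 = discriminant(_1); switchInt(move _2) ...` in MIR, so `r` is
    // only inspected, never borrowed or written through a pointer.
    pub fn is_ok(r: Result<(), u32>) -> bool {
        matches!(r, Ok(_))
    }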

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

+1-9
@@ -62,7 +62,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
         let callee_ty = instance.ty(bx.tcx(), bx.typing_env());
 
         let ty::FnDef(def_id, fn_args) = *callee_ty.kind() else {
-            bug!("expected fn item type, found {}", callee_ty);
+            span_bug!(span, "expected fn item type, found {}", callee_ty);
         };
 
         let sig = callee_ty.fn_sig(bx.tcx());
@@ -325,14 +325,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                 }
             }
 
-            sym::discriminant_value => {
-                if ret_ty.is_integral() {
-                    args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty)
-                } else {
-                    span_bug!(span, "Invalid discriminant type for `{:?}`", arg_tys[0])
-                }
-            }
-
             // This requires that atomic intrinsics follow a specific naming pattern:
             // "atomic_<operation>[_<ordering>]"
            name if let Some(atomic) = name_str.strip_prefix("atomic_") => {

compiler/rustc_codegen_ssa/src/mir/operand.rs

+146-2
@@ -3,16 +3,17 @@ use std::fmt;
 use arrayvec::ArrayVec;
 use either::Either;
 use rustc_abi as abi;
-use rustc_abi::{Align, BackendRepr, Size};
+use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
 use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
 use rustc_middle::mir::{self, ConstValue};
 use rustc_middle::ty::Ty;
 use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_middle::{bug, span_bug};
-use tracing::debug;
+use tracing::{debug, instrument};
 
 use super::place::{PlaceRef, PlaceValue};
 use super::{FunctionCx, LocalRef};
+use crate::common::IntPredicate;
 use crate::traits::*;
 use crate::{MemFlags, size_of_val};
 
@@ -415,6 +416,149 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
 
         OperandRef { val, layout: field }
     }
+
+    /// Obtain the actual discriminant of a value.
+    #[instrument(level = "trace", skip(fx, bx))]
+    pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
+        self,
+        fx: &mut FunctionCx<'a, 'tcx, Bx>,
+        bx: &mut Bx,
+        cast_to: Ty<'tcx>,
+    ) -> V {
+        let dl = &bx.tcx().data_layout;
+        let cast_to_layout = bx.cx().layout_of(cast_to);
+        let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
+
+        // We check uninhabitedness separately because a type like
+        // `enum Foo { Bar(i32, !) }` is still reported as `Variants::Single`,
+        // *not* as `Variants::Empty`.
+        if self.layout.is_uninhabited() {
+            return bx.cx().const_poison(cast_to);
+        }
+
+        let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
+            Variants::Empty => unreachable!("we already handled uninhabited types"),
+            Variants::Single { index } => {
+                let discr_val =
+                    if let Some(discr) = self.layout.ty.discriminant_for_variant(bx.tcx(), index) {
+                        discr.val
+                    } else {
+                        // This arm is for types which are neither enums nor coroutines,
+                        // and thus for which the only possible "variant" should be the first one.
+                        assert_eq!(index, FIRST_VARIANT);
+                        // There's thus no actual discriminant to return, so we return
+                        // what it would have been if this was a single-variant enum.
+                        0
+                    };
+                return bx.cx().const_uint_big(cast_to, discr_val);
+            }
+            Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
+                (tag, tag_encoding, tag_field)
+            }
+        };
+
+        // Read the tag/niche-encoded discriminant from memory.
+        let tag_op = match self.val {
+            OperandValue::ZeroSized => bug!(),
+            OperandValue::Immediate(_) | OperandValue::Pair(_, _) => {
+                self.extract_field(fx, bx, tag_field)
+            }
+            OperandValue::Ref(place) => {
+                let tag = place.with_type(self.layout).project_field(bx, tag_field);
+                bx.load_operand(tag)
+            }
+        };
+        let tag_imm = tag_op.immediate();
+
+        // Decode the discriminant (specifically if it's niche-encoded).
+        match *tag_encoding {
+            TagEncoding::Direct => {
+                let signed = match tag_scalar.primitive() {
+                    // We use `i1` for bytes that are always `0` or `1`,
+                    // e.g., `#[repr(i8)] enum E { A, B }`, but we can't
+                    // let LLVM interpret the `i1` as signed, because
+                    // then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
+                    Primitive::Int(_, signed) => !tag_scalar.is_bool() && signed,
+                    _ => false,
+                };
+                bx.intcast(tag_imm, cast_to, signed)
+            }
+            TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
+                // Cast to an integer so we don't have to treat a pointer as a
+                // special case.
+                let (tag, tag_llty) = match tag_scalar.primitive() {
+                    // FIXME(erikdesjardins): handle non-default addrspace ptr sizes
+                    Primitive::Pointer(_) => {
+                        let t = bx.type_from_integer(dl.ptr_sized_integer());
+                        let tag = bx.ptrtoint(tag_imm, t);
+                        (tag, t)
+                    }
+                    _ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
+                };
+
+                let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
+
+                // We have a subrange `niche_start..=niche_end` inside `range`.
+                // If the value of the tag is inside this subrange, it's a
+                // "niche value", an increment of the discriminant. Otherwise it
+                // indicates the untagged variant.
+                // A general algorithm to extract the discriminant from the tag
+                // is:
+                // relative_tag = tag - niche_start
+                // is_niche = relative_tag <= (ule) relative_max
+                // discr = if is_niche {
+                //     cast(relative_tag) + niche_variants.start()
+                // } else {
+                //     untagged_variant
+                // }
+                // However, we will likely be able to emit simpler code.
+                let (is_niche, tagged_discr, delta) = if relative_max == 0 {
+                    // Best case scenario: only one tagged variant. This will
+                    // likely become just a comparison and a jump.
+                    // The algorithm is:
+                    // is_niche = tag == niche_start
+                    // discr = if is_niche {
+                    //     niche_start
+                    // } else {
+                    //     untagged_variant
+                    // }
+                    let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
+                    let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
+                    let tagged_discr =
+                        bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
+                    (is_niche, tagged_discr, 0)
+                } else {
+                    // The special cases don't apply, so we'll have to go with
+                    // the general algorithm.
+                    let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
+                    let cast_tag = bx.intcast(relative_discr, cast_to, false);
+                    let is_niche = bx.icmp(
+                        IntPredicate::IntULE,
+                        relative_discr,
+                        bx.cx().const_uint(tag_llty, relative_max as u64),
+                    );
+                    (is_niche, cast_tag, niche_variants.start().as_u32() as u128)
+                };
+
+                let tagged_discr = if delta == 0 {
+                    tagged_discr
+                } else {
+                    bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
+                };
+
+                let discr = bx.select(
+                    is_niche,
+                    tagged_discr,
+                    bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
+                );
+
+                // In principle we could insert assumes on the possible range of `discr`, but
+                // currently in LLVM this seems to be a pessimization.
+
+                discr
+            }
+        }
+    }
 }
 
 impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
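
To make the niche-decoding arithmetic in the new `codegen_get_discr` easier to follow, here is a small standalone model of the `TagEncoding::Niche` case (an illustrative sketch with made-up parameter names and values, not compiler code; it ignores the `relative_max == 0` fast path, which is only an optimization):

    /// Illustrative model of the general algorithm described in the comments
    /// above: relative_tag = tag - niche_start; is_niche = relative_tag <=
    /// relative_max; then either offset into the niche variants or fall back
    /// to the untagged variant.
    fn decode_niche_discr(
        tag: u128,
        niche_start: u128,
        niche_variants_start: u32,
        niche_variants_end: u32,
        untagged_variant: u32,
    ) -> u32 {
        let relative_max = niche_variants_end - niche_variants_start;
        // Wrapping subtraction mirrors the emitted `sub` on the tag's integer type.
        let relative_tag = tag.wrapping_sub(niche_start);
        if relative_tag <= relative_max as u128 {
            // A niche value: an offset from the first niche-encoded variant.
            relative_tag as u32 + niche_variants_start
        } else {
            // Anything outside the niche subrange encodes the untagged variant.
            untagged_variant
        }
    }

    fn main() {
        // Made-up layout: variants 1..=3 are niche-encoded starting at tag 2,
        // and variant 0 is the untagged one.
        assert_eq!(decode_niche_discr(3, 2, 1, 3, 0), 2); // inside the niche range
        assert_eq!(decode_niche_discr(7, 2, 1, 3, 0), 0); // untagged variant
    }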

compiler/rustc_codegen_ssa/src/mir/place.rs

-124
@@ -1,4 +1,3 @@
-use rustc_abi::Primitive::{Int, Pointer};
 use rustc_abi::{Align, BackendRepr, FieldsShape, Size, TagEncoding, VariantIdx, Variants};
 use rustc_middle::mir::PlaceTy;
 use rustc_middle::mir::interpret::Scalar;
@@ -233,129 +232,6 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
         val.with_type(field)
     }
 
-    /// Obtain the actual discriminant of a value.
-    #[instrument(level = "trace", skip(bx))]
-    pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
-        self,
-        bx: &mut Bx,
-        cast_to: Ty<'tcx>,
-    ) -> V {
-        let dl = &bx.tcx().data_layout;
-        let cast_to_layout = bx.cx().layout_of(cast_to);
-        let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
-        if self.layout.is_uninhabited() {
-            return bx.cx().const_poison(cast_to);
-        }
-        let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
-            Variants::Empty => unreachable!("we already handled uninhabited types"),
-            Variants::Single { index } => {
-                let discr_val = self
-                    .layout
-                    .ty
-                    .discriminant_for_variant(bx.cx().tcx(), index)
-                    .map_or(index.as_u32() as u128, |discr| discr.val);
-                return bx.cx().const_uint_big(cast_to, discr_val);
-            }
-            Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
-                (tag, tag_encoding, tag_field)
-            }
-        };
-
-        // Read the tag/niche-encoded discriminant from memory.
-        let tag = self.project_field(bx, tag_field);
-        let tag_op = bx.load_operand(tag);
-        let tag_imm = tag_op.immediate();
-
-        // Decode the discriminant (specifically if it's niche-encoded).
-        match *tag_encoding {
-            TagEncoding::Direct => {
-                let signed = match tag_scalar.primitive() {
-                    // We use `i1` for bytes that are always `0` or `1`,
-                    // e.g., `#[repr(i8)] enum E { A, B }`, but we can't
-                    // let LLVM interpret the `i1` as signed, because
-                    // then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
-                    Int(_, signed) => !tag_scalar.is_bool() && signed,
-                    _ => false,
-                };
-                bx.intcast(tag_imm, cast_to, signed)
-            }
-            TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
-                // Cast to an integer so we don't have to treat a pointer as a
-                // special case.
-                let (tag, tag_llty) = match tag_scalar.primitive() {
-                    // FIXME(erikdesjardins): handle non-default addrspace ptr sizes
-                    Pointer(_) => {
-                        let t = bx.type_from_integer(dl.ptr_sized_integer());
-                        let tag = bx.ptrtoint(tag_imm, t);
-                        (tag, t)
-                    }
-                    _ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
-                };
-
-                let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
-
-                // We have a subrange `niche_start..=niche_end` inside `range`.
-                // If the value of the tag is inside this subrange, it's a
-                // "niche value", an increment of the discriminant. Otherwise it
-                // indicates the untagged variant.
-                // A general algorithm to extract the discriminant from the tag
-                // is:
-                // relative_tag = tag - niche_start
-                // is_niche = relative_tag <= (ule) relative_max
-                // discr = if is_niche {
-                //     cast(relative_tag) + niche_variants.start()
-                // } else {
-                //     untagged_variant
-                // }
-                // However, we will likely be able to emit simpler code.
-                let (is_niche, tagged_discr, delta) = if relative_max == 0 {
-                    // Best case scenario: only one tagged variant. This will
-                    // likely become just a comparison and a jump.
-                    // The algorithm is:
-                    // is_niche = tag == niche_start
-                    // discr = if is_niche {
-                    //     niche_start
-                    // } else {
-                    //     untagged_variant
-                    // }
-                    let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
-                    let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
-                    let tagged_discr =
-                        bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
-                    (is_niche, tagged_discr, 0)
-                } else {
-                    // The special cases don't apply, so we'll have to go with
-                    // the general algorithm.
-                    let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
-                    let cast_tag = bx.intcast(relative_discr, cast_to, false);
-                    let is_niche = bx.icmp(
-                        IntPredicate::IntULE,
-                        relative_discr,
-                        bx.cx().const_uint(tag_llty, relative_max as u64),
-                    );
-                    (is_niche, cast_tag, niche_variants.start().as_u32() as u128)
-                };
-
-                let tagged_discr = if delta == 0 {
-                    tagged_discr
-                } else {
-                    bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
-                };
-
-                let discr = bx.select(
-                    is_niche,
-                    tagged_discr,
-                    bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
-                );
-
-                // In principle we could insert assumes on the possible range of `discr`, but
-                // currently in LLVM this seems to be a pessimization.
-
-                discr
-            }
-        }
-    }
-
     /// Sets the discriminant for a new value of the given case of the given
     /// representation.
     pub fn codegen_set_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+2-1
@@ -706,7 +706,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
             mir::Rvalue::Discriminant(ref place) => {
                 let discr_ty = rvalue.ty(self.mir, bx.tcx());
                 let discr_ty = self.monomorphize(discr_ty);
-                let discr = self.codegen_place(bx, place.as_ref()).codegen_get_discr(bx, discr_ty);
+                let operand = self.codegen_consume(bx, place.as_ref());
+                let discr = operand.codegen_get_discr(self, bx, discr_ty);
                 OperandRef {
                     val: OperandValue::Immediate(discr),
                     layout: self.cx.layout_of(discr_ty),
