Skip to content

Re-enable the early otherwise branch optimization #122387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 53 additions & 119 deletions compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{Ty, TyCtxt};
use std::fmt::Debug;

use super::simplify::simplify_cfg;
Expand All @@ -11,6 +11,7 @@ use super::simplify::simplify_cfg;
/// let y: Option<()>;
/// match (x,y) {
/// (Some(_), Some(_)) => {0},
/// (None, None) => {2},
/// _ => {1}
/// }
/// ```
Expand All @@ -23,10 +24,10 @@ use super::simplify::simplify_cfg;
/// if discriminant_x == discriminant_y {
/// match x {
/// Some(_) => 0,
/// _ => 1, // <----
/// } // | Actually the same bb
/// } else { // |
/// 1 // <--------------
/// None => 2,
/// }
/// } else {
/// 1
/// }
/// ```
///
Expand All @@ -47,18 +48,18 @@ use super::simplify::simplify_cfg;
/// | | |
/// ================= | | |
/// | BBU | <-| | | ============================
/// |---------------| | \-------> | BBD |
/// |---------------| | | |--------------------------|
/// | unreachable | | | | _dl = discriminant(P) |
/// ================= | | |--------------------------|
/// | | | switchInt(_dl) |
/// ================= | | | d | ---> BBD.2
/// |---------------| \-------> | BBD |
/// |---------------| | |--------------------------|
/// | unreachable | | | _dl = discriminant(P) |
/// ================= | |--------------------------|
/// | | switchInt(_dl) |
/// ================= | | d | ---> BBD.2
/// | BB9 | <--------------- | otherwise |
/// |---------------| ============================
/// | ... |
/// =================
/// ```
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
/// code:
/// - `BB1` is `parent` and `BBC, BBD` are children
/// - `P` is `child_place`
Expand All @@ -78,7 +79,7 @@ use super::simplify::simplify_cfg;
/// |---------------------| | | switchInt(Q) |
/// | switchInt(_t) | | | c | ---> BBC.2
/// | false | --------/ | d | ---> BBD.2
/// | otherwise | ---------------- | otherwise |
/// | otherwise | /--------- | otherwise |
/// ======================= | ============================
/// |
/// ================= |
Expand All @@ -87,16 +88,11 @@ use super::simplify::simplify_cfg;
/// | ... |
/// =================
/// ```
///
/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
/// The filter on which `P` are allowed (together with discussion of its correctness) is found in
/// `may_hoist`.
pub struct EarlyOtherwiseBranch;

impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
// unsound: https://github.com/rust-lang/rust/issues/95162
sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
sess.mir_opt_level() >= 2
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
Expand Down Expand Up @@ -172,7 +168,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
};
(value, targets.target_for_value(value))
});
let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
// The otherwise either is the same target branch or an unreachable.
let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());

// Create `bbEq` in example above
let eq_switch = BasicBlockData::new(Some(Terminator {
Expand Down Expand Up @@ -217,85 +214,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
}
}

/// Returns true if computing the discriminant of `place` may be hoisted out of the branch
fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
// FIXME(JakobDegen): This is unsound. Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// unsafe { *ptr = 10; } // Any invalid value for the type
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
// invalid value, which is UB.
//
// In order to fix this, we would either need to show that the discriminant computation of
// `place` is computed in all branches, including the `otherwise` branch, or we would need
// another analysis pass to determine that the place is fully initialized. It might even be best
// to have the hoisting be performed in a different pass and just do the CFG changing in this
// pass.
for (place, proj) in place.iter_projections() {
match proj {
// Dereferencing in the computation of `place` might cause issues from one of two
// categories. First, the referent might be invalid. We protect against this by
// dereferencing references only (not pointers). Second, the use of a reference may
// invalidate other references that are used later (for aliasing reasons). Consider
// where such an invalidated reference may appear:
// - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
// cannot contain referenced data.
// - In `BBU`: Not possible since that block contains only the `unreachable` terminator
// - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
// reaching that block in the input to our transformation, and so any data
// invalidated by that computation could not have been used there.
// - In `BB9`: Not possible since control flow might have reached `BB9` via the
// `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
// have invalidated the data when computing `discriminant(P)`
// So dereferencing here is correct.
ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
ty::Ref(..) => {}
_ => return false,
},
// Field projections are always valid
ProjectionElem::Field(..) => {}
// We cannot allow
// downcasts either, since the correctness of the downcast may depend on the parent
// branch being taken. An easy example of this is
// ```
// Q = discriminant(_3)
// P = (_3 as Variant)
// ```
// However, checking if the child and parent place are the same and only erroring then
// is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
// be replaced by another optimization pass with any other condition that can be proven
// equivalent.
ProjectionElem::Downcast(..) => {
return false;
}
// We cannot allow indexing since the index may be out of bounds.
_ => {
return false;
}
}
}
true
}

#[derive(Debug)]
struct OptimizationData<'tcx> {
destination: BasicBlock,
Expand All @@ -315,18 +233,40 @@ fn evaluate_candidate<'tcx>(
return None;
};
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
let parent_dest = {
let poss = targets.otherwise();
// If the fallthrough on the parent is trivially unreachable, we can let the
// children choose the destination
if bbs[poss].statements.len() == 0
&& bbs[poss].terminator().kind == TerminatorKind::Unreachable
{
None
} else {
Some(poss)
}
};
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// // It may be difficult for us to effectively determine whether values are valid.
// // Invalid values can come from all sorts of corners.
// unsafe { *ptr = 10; }
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
// invalid value, which is UB.
// In order to fix this, **we would either need to show that the discriminant computation of
// `place` is computed in all branches**.
// FIXME(#95162) For the moment, we adopt a conservative approach and
// consider only the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
let (_, child) = targets.iter().next()?;
let child_terminator = &bbs[child].terminator();
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
Expand All @@ -344,13 +284,7 @@ fn evaluate_candidate<'tcx>(
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
return None;
};
let destination = parent_dest.unwrap_or(child_targets.otherwise());

// Verify that the optimization is legal in general
// We can hoist evaluating the child discriminant out of the branch
if !may_hoist(tcx, body, *child_place) {
return None;
}
let destination = child_targets.otherwise();

// Verify that the optimization is legal for each branch
for (value, child) in targets.iter() {
Expand Down Expand Up @@ -411,5 +345,5 @@ fn verify_candidate_branch<'tcx>(
if let Some(_) = iter.next() {
return false;
}
return true;
true
}
26 changes: 26 additions & 0 deletions tests/codegen/enum/enum-early-otherwise-branch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//@ compile-flags: -O
//@ min-llvm-version: 18

#![crate_type = "lib"]

pub enum Enum {
A(u32),
B(u32),
C(u32),
}

#[no_mangle]
pub fn foo(lhs: &Enum, rhs: &Enum) -> bool {
// CHECK-LABEL: define{{.*}}i1 @foo(
// CHECK-NOT: switch
// CHECK-NOT: br
// CHECK: [[SELECT:%.*]] = select
// CHECK-NEXT: ret i1 [[SELECT]]
// CHECK-NEXT: }
match (lhs, rhs) {
(Enum::A(lhs), Enum::A(rhs)) => lhs == rhs,
(Enum::B(lhs), Enum::B(rhs)) => lhs == rhs,
(Enum::C(lhs), Enum::C(rhs)) => lhs == rhs,
_ => false,
}
}
39 changes: 13 additions & 26 deletions tests/mir-opt/early_otherwise_branch.opt1.EarlyOtherwiseBranch.diff
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
let mut _7: isize;
let _8: u32;
let _9: u32;
+ let mut _10: isize;
+ let mut _11: bool;
scope 1 {
debug a => _8;
debug b => _9;
Expand All @@ -29,48 +27,37 @@
StorageDead(_5);
StorageDead(_4);
_7 = discriminant((_3.0: std::option::Option<u32>));
- switchInt(move _7) -> [1: bb2, otherwise: bb1];
+ StorageLive(_10);
+ _10 = discriminant((_3.1: std::option::Option<u32>));
+ StorageLive(_11);
+ _11 = Ne(_7, move _10);
+ StorageDead(_10);
+ switchInt(move _11) -> [0: bb4, otherwise: bb1];
switchInt(move _7) -> [1: bb2, 0: bb1, otherwise: bb5];
}

bb1: {
+ StorageDead(_11);
_0 = const 1_u32;
- goto -> bb4;
+ goto -> bb3;
goto -> bb4;
}

bb2: {
- _6 = discriminant((_3.1: std::option::Option<u32>));
- switchInt(move _6) -> [1: bb3, otherwise: bb1];
- }
-
- bb3: {
_6 = discriminant((_3.1: std::option::Option<u32>));
switchInt(move _6) -> [1: bb3, 0: bb1, otherwise: bb5];
}
bb3: {
StorageLive(_8);
_8 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
StorageLive(_9);
_9 = (((_3.1: std::option::Option<u32>) as Some).0: u32);
_0 = const 0_u32;
StorageDead(_9);
StorageDead(_8);
- goto -> bb4;
+ goto -> bb3;
goto -> bb4;
}

- bb4: {
+ bb3: {
bb4: {
StorageDead(_3);
return;
+ }
+
+ bb4: {
+ StorageDead(_11);
+ switchInt(_7) -> [1: bb2, otherwise: bb1];
}

bb5: {
unreachable;
}
}

Loading
Loading