Skip to content

Simplify match based on the cast result of IntToInt #127324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4346,6 +4346,7 @@ dependencies = [
"rustc_span",
"rustc_target",
"rustc_trait_selection",
"rustc_type_ir",
"smallvec",
"tracing",
]
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_transform/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ rustc_session = { path = "../rustc_session" }
rustc_span = { path = "../rustc_span" }
rustc_target = { path = "../rustc_target" }
rustc_trait_selection = { path = "../rustc_trait_selection" }
rustc_type_ir = { path = "../rustc_type_ir" }
smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
tracing = "0.1"
# tidy-alphabetical-end
146 changes: 80 additions & 66 deletions compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use std::iter;
use rustc_index::IndexSlice;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
use rustc_target::abi::Size;
use rustc_target::abi::Integer;
use rustc_type_ir::TyKind::*;

use super::simplify::simplify_cfg;

Expand Down Expand Up @@ -42,10 +44,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
should_cleanup = true;
continue;
}
// unsound: https://github.com/rust-lang/rust/issues/124150
if tcx.sess.opts.unstable_opts.unsound_mir_opts
&& SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env).is_some()
{
if SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env).is_some() {
should_cleanup = true;
continue;
}
Expand Down Expand Up @@ -264,33 +263,56 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
}
}

/// Check if the cast constant using `IntToInt` is equal to the target constant.
fn can_cast(
tcx: TyCtxt<'_>,
src_val: impl Into<u128>,
src_layout: TyAndLayout<'_>,
cast_ty: Ty<'_>,
target_scalar: ScalarInt,
) -> bool {
let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
let v = match src_layout.ty.kind() {
Uint(_) => from_scalar.to_uint(src_layout.size),
Int(_) => from_scalar.to_int(src_layout.size) as u128,
_ => unreachable!("invalid int"),
};
let size = match *cast_ty.kind() {
Int(t) => Integer::from_int_ty(&tcx, t).size(),
Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
_ => unreachable!("invalid int"),
};
let v = size.truncate(v);
let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
cast_scalar == target_scalar
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code does not exactly match what the interpreter does in cast_from_int_like. Should we reuse similar code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.
When I read this part of the code earlier, I tried to call it directly, which seemed a bit difficult. I copied part of the code. Maybe I should put such code in rustc_middle?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rustc_const_eval is accessible from this crate, so putting is there is enough.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, I have squashed them into two commits.


#[derive(Default)]
struct SimplifyToExp {
transfrom_types: Vec<TransfromType>,
transfrom_kinds: Vec<TransfromKind>,
}

#[derive(Clone, Copy)]
enum CompareType<'tcx, 'a> {
enum ExpectedTransformKind<'tcx, 'a> {
/// Identical statements.
Same(&'a StatementKind<'tcx>),
/// Assignment statements have the same value.
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt },
/// Enum variant comparison type.
Discr { place: &'a Place<'tcx>, ty: Ty<'tcx>, is_signed: bool },
Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> },
}

enum TransfromType {
enum TransfromKind {
Same,
Eq,
Discr,
Cast,
}

impl From<CompareType<'_, '_>> for TransfromType {
fn from(compare_type: CompareType<'_, '_>) -> Self {
impl From<ExpectedTransformKind<'_, '_>> for TransfromKind {
fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self {
match compare_type {
CompareType::Same(_) => TransfromType::Same,
CompareType::Eq(_, _, _) => TransfromType::Eq,
CompareType::Discr { .. } => TransfromType::Discr,
ExpectedTransformKind::Same(_) => TransfromKind::Same,
ExpectedTransformKind::SameByEq { .. } => TransfromKind::Same,
ExpectedTransformKind::Cast { .. } => TransfromKind::Cast,
}
}
}
Expand Down Expand Up @@ -354,7 +376,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return None;
}
let mut target_iter = targets.iter();
let (first_val, first_target) = target_iter.next().unwrap();
let (first_case_val, first_target) = target_iter.next().unwrap();
let first_terminator_kind = &bbs[first_target].terminator().kind;
// Check that destinations are identical, and if not, then don't optimize this block
if !targets
Expand All @@ -364,24 +386,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return None;
}

let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
let discr_layout = tcx.layout_of(param_env.and(discr_ty)).unwrap();
let first_stmts = &bbs[first_target].statements;
let (second_val, second_target) = target_iter.next().unwrap();
let (second_case_val, second_target) = target_iter.next().unwrap();
let second_stmts = &bbs[second_target].statements;
if first_stmts.len() != second_stmts.len() {
return None;
}

fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
l.to_bits_unchecked() == ScalarInt::try_from_uint(r, size).unwrap().to_bits_unchecked()
}

// We first compare the two branches, and then the other branches need to fulfill the same conditions.
let mut compare_types = Vec::new();
let mut expected_transform_kinds = Vec::new();
for (f, s) in iter::zip(first_stmts, second_stmts) {
let compare_type = match (&f.kind, &s.kind) {
// If two statements are exactly the same, we can optimize.
(f_s, s_s) if f_s == s_s => CompareType::Same(f_s),
(f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s),

// If two statements are assignments with the match values to the same place, we can optimize.
(
Expand All @@ -395,22 +413,29 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
f_c.const_.try_eval_scalar_int(tcx, param_env),
s_c.const_.try_eval_scalar_int(tcx, param_env),
) {
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
// Enum variants can also be simplified to an assignment statement if their values are equal.
// We need to consider both unsigned and signed scenarios here.
(Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq {
place: lhs_f,
ty: f_c.const_.ty(),
scalar: f,
},
// Enum variants can also be simplified to an assignment statement,
// if we can use `IntToInt` cast to get an equal value.
(Some(f), Some(s))
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
&& int_equal(f, first_val, discr_size)
&& int_equal(s, second_val, discr_size))
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s)
== ScalarInt::try_from_uint(second_val, s.size())) =>
if (can_cast(
tcx,
first_case_val,
discr_layout,
f_c.const_.ty(),
f,
) && can_cast(
tcx,
second_case_val,
discr_layout,
f_c.const_.ty(),
s,
)) =>
{
CompareType::Discr {
place: lhs_f,
ty: f_c.const_.ty(),
is_signed: f_c.const_.ty().is_signed() || discr_ty.is_signed(),
}
ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() }
}
_ => {
return None;
Expand All @@ -421,47 +446,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
// Otherwise we cannot optimize. Try another block.
_ => return None,
};
compare_types.push(compare_type);
expected_transform_kinds.push(compare_type);
}

// All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
for (other_val, other_target) in target_iter {
let other_stmts = &bbs[other_target].statements;
if compare_types.len() != other_stmts.len() {
if expected_transform_kinds.len() != other_stmts.len() {
return None;
}
for (f, s) in iter::zip(&compare_types, other_stmts) {
for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) {
match (*f, &s.kind) {
(CompareType::Same(f_s), s_s) if f_s == s_s => {}
(ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {}
(
CompareType::Eq(lhs_f, f_ty, val),
ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar },
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
) if lhs_f == lhs_s
&& s_c.const_.ty() == f_ty
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(scalar) => {}
(
CompareType::Discr { place: lhs_f, ty: f_ty, is_signed },
ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty },
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
return None;
};
if is_signed
&& s_c.const_.ty().is_signed()
&& int_equal(f, other_val, discr_size)
{
continue;
}
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
continue;
}
return None;
}
) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env)
&& lhs_f == lhs_s
&& s_c.const_.ty() == f_ty
&& can_cast(tcx, other_val, discr_layout, f_ty, f) => {}
_ => return None,
}
}
}
self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect();
self.transfrom_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect();
Some(())
}

Expand All @@ -479,13 +493,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
let (_, first) = targets.iter().next().unwrap();
let first = &bbs[first];

for (t, s) in iter::zip(&self.transfrom_types, &first.statements) {
for (t, s) in iter::zip(&self.transfrom_kinds, &first.statements) {
match (t, &s.kind) {
(TransfromType::Same, _) | (TransfromType::Eq, _) => {
(TransfromKind::Same, _) => {
patch.add_statement(parent_end, s.kind.clone());
}
(
TransfromType::Discr,
TransfromKind::Cast,
StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
) => {
let operand = Operand::Copy(Place::from(discr_local));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,42 @@
debug i => _1;
let mut _0: u128;
let mut _2: i128;
+ let mut _3: i128;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [1: bb5, 2: bb4, 3: bb3, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = const core::num::<impl u128>::MAX;
goto -> bb6;
}

bb3: {
_0 = const 3_u128;
goto -> bb6;
}

bb4: {
_0 = const 2_u128;
goto -> bb6;
}

bb5: {
_0 = const 1_u128;
goto -> bb6;
}

bb6: {
- switchInt(move _2) -> [1: bb5, 2: bb4, 3: bb3, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
- }
-
- bb1: {
- unreachable;
- }
-
- bb2: {
- _0 = const core::num::<impl u128>::MAX;
- goto -> bb6;
- }
-
- bb3: {
- _0 = const 3_u128;
- goto -> bb6;
- }
-
- bb4: {
- _0 = const 2_u128;
- goto -> bb6;
- }
-
- bb5: {
- _0 = const 1_u128;
- goto -> bb6;
- }
-
- bb6: {
+ StorageLive(_3);
+ _3 = move _2;
+ _0 = _3 as u128 (IntToInt);
+ StorageDead(_3);
return;
}
}
Expand Down

This file was deleted.

Loading
Loading