Skip to content

Commit 1f061f4

Browse files
committed
Transforms match into an assignment statement
1 parent 7af7458 commit 1f061f4

9 files changed

+370
-117
lines changed

compiler/rustc_middle/src/mir/terminator.rs

+6
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ impl SwitchTargets {
8585
self.values.push(value);
8686
self.targets.insert(self.targets.len() - 1, bb);
8787
}
88+
89+
/// Returns `true` if all targets (including the fallback target) are distinct.
///
/// The check collects `self.targets` into an `FxHashSet` and compares the
/// set's length with `self.targets.len()`; a repeated target makes the set
/// smaller, so the lengths differ.
90+
#[inline]
91+
pub fn is_distinct(&self) -> bool {
92+
self.targets.iter().collect::<FxHashSet<_>>().len() == self.targets.len()
93+
}
8894
}
8995

9096
pub struct SwitchTargetsIter<'a> {

compiler/rustc_mir_transform/src/match_branches.rs

+223-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rustc_index::IndexVec;
22
use rustc_middle::mir::*;
3-
use rustc_middle::ty::{ParamEnv, Ty, TyCtxt};
3+
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
44
use std::iter;
55

66
use super::simplify::simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838
should_cleanup = true;
3939
continue;
4040
}
41+
if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env)
42+
{
43+
should_cleanup = true;
44+
continue;
45+
}
4146
}
4247

4348
if should_cleanup {
@@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4752
}
4853

4954
trait SimplifyMatch<'tcx> {
55+
/// Simplifies a match statement, returning true if the simplification succeeds, false otherwise.
56+
/// Generic code is written here, and we generally don't need a custom implementation.
5057
fn simplify(
51-
&self,
58+
&mut self,
5259
tcx: TyCtxt<'tcx>,
5360
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
5461
bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
@@ -72,9 +79,7 @@ trait SimplifyMatch<'tcx> {
7279
let source_info = bbs[switch_bb_idx].terminator().source_info;
7380
let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span));
7481

75-
// We already checked that first and second are different blocks,
76-
// and bb_idx has a different terminator from both of them.
77-
let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty);
82+
let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local, discr_ty);
7883
let (_, first) = targets.iter().next().unwrap();
7984
let (from, first) = bbs.pick2_mut(switch_bb_idx, first);
8085
from.statements
@@ -90,8 +95,11 @@ trait SimplifyMatch<'tcx> {
9095
true
9196
}
9297

98+
/// Check that the BBs to be simplified are all distinct and
99+
/// that their terminators are the same.
100+
/// Each way of simplification may impose additional conditions of its own.
93101
fn can_simplify(
94-
&self,
102+
&mut self,
95103
tcx: TyCtxt<'tcx>,
96104
targets: &SwitchTargets,
97105
param_env: ParamEnv<'tcx>,
@@ -144,7 +152,7 @@ struct SimplifyToIf;
144152
/// ```
145153
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
146154
fn can_simplify(
147-
&self,
155+
&mut self,
148156
tcx: TyCtxt<'tcx>,
149157
targets: &SwitchTargets,
150158
param_env: ParamEnv<'tcx>,
@@ -250,3 +258,211 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250258
new_stmts.collect()
251259
}
252260
}
261+
262+
/// State for the simplification that turns a `switchInt` into plain
/// assignment statements.
///
/// `can_simplify` fills `transfrom_types` on success and `new_stmts` reads it
/// back, pairing each entry with one statement of the first target block.
///
/// NOTE(review): `transfrom` looks like a misspelling of `transform`; a rename
/// would have to touch every use of the field and of `TransfromType`.
#[derive(Default)]
263+
struct SimplifyToExp {
264+
// One entry per statement of the first target block, in statement order.
transfrom_types: Vec<TransfromType>,
265+
}
266+
267+
/// How one statement position was classified when the first two target
/// blocks were compared in `SimplifyToExp::can_simplify`. The borrowed data
/// is what every remaining target block is later checked against.
#[derive(Clone, Copy)]
268+
enum CompareType<'tcx, 'a> {
269+
/// Both blocks hold an identical statement; carries the first block's
/// statement kind.
Same(&'a StatementKind<'tcx>),
270+
/// Both blocks assign the same integer constant to the same place;
/// carries the place, the constant's type and its evaluated value.
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
271+
/// Each block assigns, to the same place, an integer constant that
/// matches that block's own switch value; carries the place and the
/// constant's type.
Discr(&'a Place<'tcx>, Ty<'tcx>),
272+
}
273+
274+
/// Borrow-free counterpart of `CompareType`: it keeps only which variant was
/// chosen, so it can be stored in `SimplifyToExp` after the comparison ends.
enum TransfromType {
275+
/// Corresponds to `CompareType::Same`.
Same,
276+
/// Corresponds to `CompareType::Eq`.
Eq,
277+
/// Corresponds to `CompareType::Discr`.
Discr,
278+
}
279+
280+
/// Drops the payload of a `CompareType` and keeps only its variant.
impl From<CompareType<'_, '_>> for TransfromType {
281+
fn from(compare_type: CompareType<'_, '_>) -> Self {
282+
// One-to-one mapping by variant name; the borrowed fields are ignored.
match compare_type {
283+
CompareType::Same(_) => TransfromType::Same,
284+
CompareType::Eq(_, _, _) => TransfromType::Eq,
285+
CompareType::Discr(_, _) => TransfromType::Discr,
286+
}
287+
}
288+
}
289+
290+
/// If every arm of the match assigns its own match value to the same place,
291+
/// merge the statements of one target block into the source block,
292+
/// using a cast when the integer types differ.
293+
///
294+
/// For example:
295+
///
296+
/// ```ignore (MIR)
297+
/// bb0: {
298+
/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
299+
/// }
300+
///
301+
/// bb1: {
302+
/// unreachable;
303+
/// }
304+
///
305+
/// bb2: {
306+
/// _0 = const 1_i16;
307+
/// goto -> bb5;
308+
/// }
309+
///
310+
/// bb3: {
311+
/// _0 = const 2_i16;
312+
/// goto -> bb5;
313+
/// }
314+
///
315+
/// bb4: {
316+
/// _0 = const 3_i16;
317+
/// goto -> bb5;
318+
/// }
319+
/// ```
320+
///
321+
/// into:
322+
///
323+
/// ```ignore (MIR)
324+
/// bb0: {
325+
/// _0 = _3 as i16 (IntToInt);
326+
/// goto -> bb5;
327+
/// }
328+
/// ```
329+
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
330+
/// Decides whether the switch described by `targets` can be replaced by
/// straight-line statements.
///
/// The first two target blocks are compared statement by statement to derive
/// one `CompareType` per position; every remaining target block must then
/// fit that same sequence. On success the derived sequence is stored in
/// `self.transfrom_types` and `true` is returned. Any mismatch returns
/// `false` and leaves `self.transfrom_types` untouched.
fn can_simplify(
331+
&mut self,
332+
tcx: TyCtxt<'tcx>,
333+
targets: &SwitchTargets,
334+
param_env: ParamEnv<'tcx>,
335+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
336+
) -> bool {
337+
// Only switches with 2 to 64 `(value, target)` entries are considered.
if targets.iter().len() < 2 || targets.iter().len() > 64 {
338+
return false;
339+
}
340+
// We require that the possible target blocks all be distinct.
341+
if !targets.is_distinct() {
342+
return false;
343+
}
344+
// Give up unless `is_empty_unreachable()` holds for the `otherwise` block.
if !bbs[targets.otherwise()].is_empty_unreachable() {
345+
return false;
346+
}
347+
let mut target_iter = targets.iter();
348+
// The `unwrap` cannot fail: at least two entries were required above.
let (first_val, first_target) = target_iter.next().unwrap();
349+
let first_terminator_kind = &bbs[first_target].terminator().kind;
350+
// Check that every target ends in the same terminator kind as the first one,
// and if not, then don't optimize this block.
351+
if !targets
352+
.iter()
353+
.all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
354+
{
355+
return false;
356+
}
357+
358+
let first_stmts = &bbs[first_target].statements;
359+
// Second entry; also guaranteed to exist by the length check above.
let (second_val, second_target) = target_iter.next().unwrap();
360+
let second_stmts = &bbs[second_target].statements;
361+
if first_stmts.len() != second_stmts.len() {
362+
return false;
363+
}
364+
365+
// Classify each statement position from the first two blocks.
let mut compare_types = Vec::new();
366+
for (f, s) in iter::zip(first_stmts, second_stmts) {
367+
let compare_type = match (&f.kind, &s.kind) {
368+
// If two statements are exactly the same, we can optimize.
369+
(f_s, s_s) if f_s == s_s => CompareType::Same(f_s),
370+
371+
// If two statements are assignments with the match values to the same place, we can optimize.
372+
(
373+
StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
374+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
375+
) if lhs_f == lhs_s
376+
&& f_c.const_.ty() == s_c.const_.ty()
377+
&& f_c.const_.ty().is_integral() =>
378+
{
379+
// From here on `f` and `s` are rebound to the evaluated constants,
// shadowing the statements of the enclosing loop.
match (
380+
f_c.const_.try_eval_scalar_int(tcx, param_env),
381+
s_c.const_.try_eval_scalar_int(tcx, param_env),
382+
) {
383+
// Both blocks assign the same constant value.
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
384+
// Each constant equals its own switch value once that value is
// converted with `ScalarInt::try_from_uint` at the constant's size.
// NOTE(review): only the bit pattern at the constant's size is
// compared and signedness is never consulted; whether the
// `IntToInt` cast emitted by `new_stmts` reproduces that pattern
// for signed discriminants is not established here — confirm.
(Some(f), Some(s))
385+
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
386+
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
387+
{
388+
CompareType::Discr(lhs_f, f_c.const_.ty())
389+
}
390+
_ => return false,
391+
}
392+
}
393+
394+
// Otherwise we cannot optimize. Try another block.
395+
_ => return false,
396+
};
397+
compare_types.push(compare_type);
398+
}
399+
400+
// All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
401+
for (other_val, other_target) in target_iter {
402+
let other_stmts = &bbs[other_target].statements;
403+
if compare_types.len() != other_stmts.len() {
404+
return false;
405+
}
406+
for (f, s) in iter::zip(&compare_types, other_stmts) {
407+
match (*f, &s.kind) {
408+
// Same: the statement must equal the first block's statement.
(CompareType::Same(f_s), s_s) if f_s == s_s => {}
409+
// Eq: same place, same type, and the same constant value as recorded.
(
410+
CompareType::Eq(lhs_f, f_ty, val),
411+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
412+
) if lhs_f == lhs_s
413+
&& s_c.const_.ty() == f_ty
414+
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
415+
// Discr: same place and type, and the constant must match this
// block's own switch value.
(
416+
CompareType::Discr(lhs_f, f_ty),
417+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
418+
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
419+
// `f` is rebound here to this block's evaluated constant.
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
420+
return false;
421+
};
422+
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
423+
return false;
424+
}
425+
}
426+
_ => return false,
427+
}
428+
}
429+
}
430+
// Keep only the variant of each classification; the borrows end here.
self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect();
431+
true
432+
}
433+
434+
/// Builds the statements that replace the switch, using the first target
/// block as the template.
///
/// Each statement of the first target is paired with the corresponding entry
/// of `self.transfrom_types`:
/// - `Same` and `Eq` statements are cloned unchanged;
/// - a `Discr` statement becomes an assignment of `discr_local` to the same
///   place, with an `IntToInt` cast when the constant's type differs from
///   `discr_ty`.
///
/// `_tcx` and `_param_env` are unused by this implementation.
fn new_stmts(
435+
&self,
436+
_tcx: TyCtxt<'tcx>,
437+
targets: &SwitchTargets,
438+
_param_env: ParamEnv<'tcx>,
439+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
440+
discr_local: Local,
441+
discr_ty: Ty<'tcx>,
442+
) -> Vec<Statement<'tcx>> {
443+
// Panics if `targets` has no entries; `can_simplify` only accepts
// switches with at least two.
let (_, first) = targets.iter().next().unwrap();
444+
let first = &bbs[first];
445+
446+
let new_stmts =
447+
iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) {
448+
// These statements do not depend on the switch value: reuse them as is.
(TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(),
449+
(
450+
TransfromType::Discr,
451+
StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
452+
) => {
453+
// Read the switch value from `discr_local` instead of using the constant.
let operand = Operand::Copy(Place::from(discr_local));
454+
// A cast is only emitted when the assigned type differs from `discr_ty`.
let r_val = if f_c.const_.ty() == discr_ty {
455+
Rvalue::Use(operand)
456+
} else {
457+
Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
458+
};
459+
Statement {
460+
source_info: s.source_info,
461+
kind: StatementKind::Assign(Box::new((*lhs, r_val))),
462+
}
463+
}
464+
// `can_simplify` records `Discr` only for assignments of a constant,
// so a `Discr` entry paired with any other statement kind is a bug.
_ => unreachable!(),
465+
});
466+
new_stmts.collect()
467+
}
468+
}

tests/codegen/match-optimized.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 {
2626
// CHECK-NEXT: store i8 1, ptr %_0, align 1
2727
// CHECK-NEXT: br label %[[EXIT]]
2828
// CHECK: [[C]]:
29-
// CHECK-NEXT: store i8 2, ptr %_0, align 1
29+
// CHECK-NEXT: store i8 3, ptr %_0, align 1
3030
// CHECK-NEXT: br label %[[EXIT]]
3131
match e {
3232
E::A => 0,
3333
E::B => 1,
34-
E::C => 2,
34+
E::C => 3,
3535
}
3636
}
3737

tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff

+33-28
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,42 @@
55
debug i => _1;
66
let mut _0: u128;
77
let mut _2: i128;
8+
+ let mut _3: i128;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
12-
}
13-
14-
bb1: {
15-
unreachable;
16-
}
17-
18-
bb2: {
19-
_0 = const core::num::<impl u128>::MAX;
20-
goto -> bb6;
21-
}
22-
23-
bb3: {
24-
_0 = const 1_u128;
25-
goto -> bb6;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_u128;
30-
goto -> bb6;
31-
}
32-
33-
bb5: {
34-
_0 = const 3_u128;
35-
goto -> bb6;
36-
}
37-
38-
bb6: {
12+
- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
13+
- }
14+
-
15+
- bb1: {
16+
- unreachable;
17+
- }
18+
-
19+
- bb2: {
20+
- _0 = const core::num::<impl u128>::MAX;
21+
- goto -> bb6;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const 1_u128;
26+
- goto -> bb6;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_u128;
31+
- goto -> bb6;
32+
- }
33+
-
34+
- bb5: {
35+
- _0 = const 3_u128;
36+
- goto -> bb6;
37+
- }
38+
-
39+
- bb6: {
40+
+ StorageLive(_3);
41+
+ _3 = move _2;
42+
+ _0 = _3 as u128 (IntToInt);
43+
+ StorageDead(_3);
3944
return;
4045
}
4146
}

0 commit comments

Comments
 (0)