1
1
use rustc_index:: IndexVec ;
2
2
use rustc_middle:: mir:: * ;
3
- use rustc_middle:: ty:: { ParamEnv , Ty , TyCtxt } ;
3
+ use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
4
4
use std:: iter;
5
5
6
6
use super :: simplify:: simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
38
38
should_cleanup = true ;
39
39
continue ;
40
40
}
41
+ if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
42
+ {
43
+ should_cleanup = true ;
44
+ continue ;
45
+ }
41
46
}
42
47
43
48
if should_cleanup {
@@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
47
52
}
48
53
49
54
trait SimplifyMatch < ' tcx > {
55
+ /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise.
56
+ /// Generic code is written here, and we generally don't need a custom implementation.
50
57
fn simplify (
51
- & self ,
58
+ & mut self ,
52
59
tcx : TyCtxt < ' tcx > ,
53
60
local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
54
61
bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
@@ -72,9 +79,7 @@ trait SimplifyMatch<'tcx> {
72
79
let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
73
80
let discr_local = local_decls. push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
74
81
75
- // We already checked that first and second are different blocks,
76
- // and bb_idx has a different terminator from both of them.
77
- let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local. clone ( ) , discr_ty) ;
82
+ let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local, discr_ty) ;
78
83
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
79
84
let ( from, first) = bbs. pick2_mut ( switch_bb_idx, first) ;
80
85
from. statements
@@ -90,8 +95,11 @@ trait SimplifyMatch<'tcx> {
90
95
true
91
96
}
92
97
98
+ /// Check that the BBs to be simplified satisfies all distinct and
99
+ /// that the terminator are the same.
100
+ /// There are also conditions for different ways of simplification.
93
101
fn can_simplify (
94
- & self ,
102
+ & mut self ,
95
103
tcx : TyCtxt < ' tcx > ,
96
104
targets : & SwitchTargets ,
97
105
param_env : ParamEnv < ' tcx > ,
@@ -144,7 +152,7 @@ struct SimplifyToIf;
144
152
/// ```
145
153
impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToIf {
146
154
fn can_simplify (
147
- & self ,
155
+ & mut self ,
148
156
tcx : TyCtxt < ' tcx > ,
149
157
targets : & SwitchTargets ,
150
158
param_env : ParamEnv < ' tcx > ,
@@ -250,3 +258,211 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250
258
new_stmts. collect ( )
251
259
}
252
260
}
261
+
262
+ #[ derive( Default ) ]
263
+ struct SimplifyToExp {
264
+ transfrom_types : Vec < TransfromType > ,
265
+ }
266
+
267
+ #[ derive( Clone , Copy ) ]
268
+ enum CompareType < ' tcx , ' a > {
269
+ Same ( & ' a StatementKind < ' tcx > ) ,
270
+ Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
271
+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
272
+ }
273
+
274
+ enum TransfromType {
275
+ Same ,
276
+ Eq ,
277
+ Discr ,
278
+ }
279
+
280
+ impl From < CompareType < ' _ , ' _ > > for TransfromType {
281
+ fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
282
+ match compare_type {
283
+ CompareType :: Same ( _) => TransfromType :: Same ,
284
+ CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
285
+ CompareType :: Discr ( _, _) => TransfromType :: Discr ,
286
+ }
287
+ }
288
+ }
289
+
290
+ /// If we find that the value of match is the same as the assignment,
291
+ /// merge a target block statements into the source block,
292
+ /// using cast to transform different integer types.
293
+ ///
294
+ /// For example:
295
+ ///
296
+ /// ```ignore (MIR)
297
+ /// bb0: {
298
+ /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
299
+ /// }
300
+ ///
301
+ /// bb1: {
302
+ /// unreachable;
303
+ /// }
304
+ ///
305
+ /// bb2: {
306
+ /// _0 = const 1_i16;
307
+ /// goto -> bb5;
308
+ /// }
309
+ ///
310
+ /// bb3: {
311
+ /// _0 = const 2_i16;
312
+ /// goto -> bb5;
313
+ /// }
314
+ ///
315
+ /// bb4: {
316
+ /// _0 = const 3_i16;
317
+ /// goto -> bb5;
318
+ /// }
319
+ /// ```
320
+ ///
321
+ /// into:
322
+ ///
323
+ /// ```ignore (MIR)
324
+ /// bb0: {
325
+ /// _0 = _3 as i16 (IntToInt);
326
+ /// goto -> bb5;
327
+ /// }
328
+ /// ```
329
+ impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToExp {
330
+ fn can_simplify (
331
+ & mut self ,
332
+ tcx : TyCtxt < ' tcx > ,
333
+ targets : & SwitchTargets ,
334
+ param_env : ParamEnv < ' tcx > ,
335
+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
336
+ ) -> bool {
337
+ if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
338
+ return false ;
339
+ }
340
+ // We require that the possible target blocks all be distinct.
341
+ if !targets. is_distinct ( ) {
342
+ return false ;
343
+ }
344
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
345
+ return false ;
346
+ }
347
+ let mut target_iter = targets. iter ( ) ;
348
+ let ( first_val, first_target) = target_iter. next ( ) . unwrap ( ) ;
349
+ let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
350
+ // Check that destinations are identical, and if not, then don't optimize this block
351
+ if !targets
352
+ . iter ( )
353
+ . all ( |( _, other_target) | first_terminator_kind == & bbs[ other_target] . terminator ( ) . kind )
354
+ {
355
+ return false ;
356
+ }
357
+
358
+ let first_stmts = & bbs[ first_target] . statements ;
359
+ let ( second_val, second_target) = target_iter. next ( ) . unwrap ( ) ;
360
+ let second_stmts = & bbs[ second_target] . statements ;
361
+ if first_stmts. len ( ) != second_stmts. len ( ) {
362
+ return false ;
363
+ }
364
+
365
+ let mut compare_types = Vec :: new ( ) ;
366
+ for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
367
+ let compare_type = match ( & f. kind , & s. kind ) {
368
+ // If two statements are exactly the same, we can optimize.
369
+ ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
370
+
371
+ // If two statements are assignments with the match values to the same place, we can optimize.
372
+ (
373
+ StatementKind :: Assign ( box ( lhs_f, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
374
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
375
+ ) if lhs_f == lhs_s
376
+ && f_c. const_ . ty ( ) == s_c. const_ . ty ( )
377
+ && f_c. const_ . ty ( ) . is_integral ( ) =>
378
+ {
379
+ match (
380
+ f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
381
+ s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
382
+ ) {
383
+ ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
384
+ ( Some ( f) , Some ( s) )
385
+ if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
386
+ && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
387
+ {
388
+ CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
389
+ }
390
+ _ => return false ,
391
+ }
392
+ }
393
+
394
+ // Otherwise we cannot optimize. Try another block.
395
+ _ => return false ,
396
+ } ;
397
+ compare_types. push ( compare_type) ;
398
+ }
399
+
400
+ // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
401
+ for ( other_val, other_target) in target_iter {
402
+ let other_stmts = & bbs[ other_target] . statements ;
403
+ if compare_types. len ( ) != other_stmts. len ( ) {
404
+ return false ;
405
+ }
406
+ for ( f, s) in iter:: zip ( & compare_types, other_stmts) {
407
+ match ( * f, & s. kind ) {
408
+ ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
409
+ (
410
+ CompareType :: Eq ( lhs_f, f_ty, val) ,
411
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
412
+ ) if lhs_f == lhs_s
413
+ && s_c. const_ . ty ( ) == f_ty
414
+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
415
+ (
416
+ CompareType :: Discr ( lhs_f, f_ty) ,
417
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
418
+ ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
419
+ let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
420
+ return false ;
421
+ } ;
422
+ if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
423
+ return false ;
424
+ }
425
+ }
426
+ _ => return false ,
427
+ }
428
+ }
429
+ }
430
+ self . transfrom_types = compare_types. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
431
+ true
432
+ }
433
+
434
+ fn new_stmts (
435
+ & self ,
436
+ _tcx : TyCtxt < ' tcx > ,
437
+ targets : & SwitchTargets ,
438
+ _param_env : ParamEnv < ' tcx > ,
439
+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
440
+ discr_local : Local ,
441
+ discr_ty : Ty < ' tcx > ,
442
+ ) -> Vec < Statement < ' tcx > > {
443
+ let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
444
+ let first = & bbs[ first] ;
445
+
446
+ let new_stmts =
447
+ iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
448
+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
449
+ (
450
+ TransfromType :: Discr ,
451
+ StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
452
+ ) => {
453
+ let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
454
+ let r_val = if f_c. const_ . ty ( ) == discr_ty {
455
+ Rvalue :: Use ( operand)
456
+ } else {
457
+ Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
458
+ } ;
459
+ Statement {
460
+ source_info : s. source_info ,
461
+ kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
462
+ }
463
+ }
464
+ _ => unreachable ! ( ) ,
465
+ } ) ;
466
+ new_stmts. collect ( )
467
+ }
468
+ }
0 commit comments