@@ -24,10 +24,13 @@ struct PeepholeOptimizeImpl {
   PeepholeOptimizeImpl(
       const std::shared_ptr<Graph>& graph,
       bool disable_shape_peepholes)
-      : graph_(graph), shape_peepholes_(!disable_shape_peepholes) {
-    run(graph->block());
-    PeepholeOptimizeListIdioms(graph);
-    PeepholeOptimizeAliasSensitive(graph);
+      : graph_(graph), shape_peepholes_(!disable_shape_peepholes) {}
+
+  bool run() {
+    bool changed = optimizeBlock(graph_->block());
+    changed |= PeepholeOptimizeListIdioms(graph_);
+    changed |= PeepholeOptimizeAliasSensitive(graph_);
+    return changed;
   }
 
   // The intent for this optimization pass is to catch all of the small, easy to
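
Note: the rewritten `run()` folds each sub-pass's result into `changed` with `|=` rather than `||`. A minimal standalone sketch of why that matters (hypothetical names, not part of this patch): logical `||` short-circuits, so once one sub-pass reported a change, the remaining sub-passes would never execute.

    #include <cstdio>

    // Stand-in for a sub-pass such as PeepholeOptimizeListIdioms: it must
    // still run on every sweep even if an earlier pass already made changes.
    bool subPass() {
      std::puts("sub-pass ran");
      return false;
    }

    int main() {
      bool changed = true;
      // changed = changed || subPass();  // short-circuits: subPass() skipped
      changed |= subPass();  // evaluates subPass() unconditionally, then ORs
      return changed ? 0 : 1;
    }
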
@@ -39,12 +42,13 @@ struct PeepholeOptimizeImpl {
   //
   // TODO: Decide what kind of fixed point strategy we will have
   //
-  void run(Block* block) {
+  bool optimizeBlock(Block* block) {
+    bool changed = false;
     for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
       auto* node = *it;
 
       for (Block* sub_block : node->blocks()) {
-        run(sub_block);
+        changed |= optimizeBlock(sub_block);
       }
 
       if (node->kind() != prim::Constant) {
@@ -55,6 +59,7 @@ struct PeepholeOptimizeImpl {
         for (Value* output : node->outputs()) {
           if (output->type()->cast<NoneType>()) {
             output->replaceAllUsesWith(graph_->insertConstant(IValue()));
+            changed = true;
           }
         }
       }
@@ -71,6 +76,7 @@ struct PeepholeOptimizeImpl {
             " (x._grad_sum_to_size(x, None) == x) is replaced with ",
             node->input(0)->debugName());
         node->output()->replaceAllUsesWith(node->input(0));
+        changed = true;
       } else {
         auto uses = node->output()->uses();
         for (Use u : uses) {
@@ -82,6 +88,7 @@ struct PeepholeOptimizeImpl {
                 " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
                 node->inputs().at(0)->debugName());
             u.user->replaceInput(0, node->inputs().at(0));
+            changed = true;
           }
         }
       }
@@ -102,6 +109,7 @@ struct PeepholeOptimizeImpl {
               " (x.expand(x.size()) == x) is replaced with ",
               node->namedInput(attr::self)->debugName());
           node->output()->replaceAllUsesWith(node->namedInput(attr::self));
+          changed = true;
         }
       }
     } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
@@ -113,6 +121,7 @@ struct PeepholeOptimizeImpl {
             " (x.t().t() == x) is replaced with ",
             input_node->input()->debugName());
         node->output()->replaceAllUsesWith(input_node->input());
+        changed = true;
       }
     } else if (
         node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") &&
@@ -127,6 +136,7 @@ struct PeepholeOptimizeImpl {
             " (x.type_as(y) == x) is replaced with ",
             node->input(0)->debugName());
         node->output()->replaceAllUsesWith(node->input(0));
+        changed = true;
       }
     } else if (
         node->kind() == aten::Float || node->kind() == aten::Int ||
@@ -140,6 +150,7 @@ struct PeepholeOptimizeImpl {
             " (x.NumToTensor().TensorToNum() == x.NumToTensor()) is replaced with ",
             node->input()->debugName());
         node->output()->replaceAllUsesWith(input_node->input());
+        changed = true;
       }
     } else if (
         node->matches("aten::size(Tensor self) -> int[]") &&
@@ -154,6 +165,7 @@ struct PeepholeOptimizeImpl {
           IValue ival(sizes);
           auto const_sizes_val = node->owningGraph()->insertConstant(ival);
           node->output()->replaceAllUsesWith(const_sizes_val);
+          changed = true;
         }
       }
     } else if (
@@ -174,6 +186,7 @@ struct PeepholeOptimizeImpl {
             IValue ival(*ptt->sizes()[norm_index]);
             auto const_sizes_val = node->owningGraph()->insertConstant(ival);
             node->output()->replaceAllUsesWith(const_sizes_val);
+            changed = true;
           }
         }
       }
@@ -187,6 +200,7 @@ struct PeepholeOptimizeImpl {
         IValue ival(at::isFloatingType(dtype));
         auto new_constant = node->owningGraph()->insertConstant(ival);
         node->output()->replaceAllUsesWith(new_constant);
+        changed = true;
       }
     } else if (
         node->matches("aten::is_complex(Tensor self) -> bool") &&
@@ -220,6 +234,7 @@ struct PeepholeOptimizeImpl {
               " (True or False) with ",
               n.cond()->debugName());
           n.outputs().at(i)->replaceAllUsesWith(n.cond());
+          changed = true;
         }
       }
     } else if (
@@ -240,6 +255,7 @@ struct PeepholeOptimizeImpl {
           GRAPH_UPDATE(
               "Folding ", getHeader(node), " to ", output->debugName());
           node->output()->replaceAllUsesWith(output);
+          changed = true;
         }
       }
     } else if (
@@ -255,6 +271,7 @@ struct PeepholeOptimizeImpl {
             node->input(),
             " can't be optional");
         node->output()->replaceAllUsesWith(node->input());
+        changed = true;
       }
     } else if (node->kind() == prim::unchecked_cast) {
       // unchecked_cast is not generated for tensor properties, so we are not
@@ -267,6 +284,7 @@ struct PeepholeOptimizeImpl {
             getHeader(node),
             " as input type subtypes output type");
         node->output()->replaceAllUsesWith(node->input());
+        changed = true;
       }
     } else if (
         node->matches("prim::dtype(Tensor a) -> int") && shape_peepholes_) {
@@ -281,6 +299,7 @@ struct PeepholeOptimizeImpl {
             " with a type constant ",
             output->debugName());
         node->output()->replaceAllUsesWith(output);
+        changed = true;
       }
     } else if (
         node->matches("prim::device(Tensor a) -> Device") &&
@@ -295,6 +314,7 @@ struct PeepholeOptimizeImpl {
             " with a device constant ",
             output->debugName());
         node->output()->replaceAllUsesWith(output);
+        changed = true;
       }
     } else if (
         node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
@@ -309,6 +329,7 @@ struct PeepholeOptimizeImpl {
             " with a \"dim\" constant ",
             output->debugName());
         node->output()->replaceAllUsesWith(output);
+        changed = true;
       }
     } else if (
         node->matches("prim::is_cuda(Tensor a) -> bool") &&
@@ -324,6 +345,7 @@ struct PeepholeOptimizeImpl {
             " with a is_cuda constant ",
             output->debugName());
         node->output()->replaceAllUsesWith(output);
+        changed = true;
       }
     }
 
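
Note: the `prim::dtype`, `prim::device`, `aten::dim`, and `prim::is_cuda` branches above all fold a tensor property into a constant whenever the input's `TensorType` is complete enough. A hedged sketch of exercising one of them through the public entry point — the IR string and the exact folding behavior are illustrative assumptions, not asserted by this patch:

    #include <torch/csrc/jit/ir/irparser.h>
    #include <torch/csrc/jit/passes/peephole.h>

    void dtypePeepholeDemo() {
      auto graph = std::make_shared<torch::jit::Graph>();
      // %x carries a complete TensorType, so prim::dtype(%x) can be resolved
      // statically (assuming the parser preserves the type annotation).
      torch::jit::parseIR(R"IR(
        graph(%x : Float(2, 3)):
          %d : int = prim::dtype(%x)
          return (%d)
      )IR", graph.get());
      // Expected to return true: the dtype call folds to an int constant.
      bool changed = torch::jit::PeepholeOptimize(graph);
      (void)changed;
    }
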
@@ -333,6 +355,7 @@ struct PeepholeOptimizeImpl {
       // the limited speedup of these optimizations
       // runAliasingSensitivePeepholeTransformations(node);
     }
+    return changed;
   }
 
   // if either the inputs or outputs of an op alias graph's inputs or
@@ -345,10 +368,11 @@ struct PeepholeOptimizeImpl {
   //     s += x
   //     return s
   //
-  void runAliasingSensitivePeepholeTransformations(Node* node) {
+  bool runAliasingSensitivePeepholeTransformations(Node* node) {
     // this code is not currently enabled, see [aliasing sensitive
     // optimizations]
     TORCH_INTERNAL_ASSERT(false);
+    bool changed = false;
     if (node->matches(
             "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
             /*const_inputs=*/{attr::alpha, attr::other}) ||
@@ -363,6 +387,7 @@ struct PeepholeOptimizeImpl {
             " (x + 0 == x - 0 == x) is replaced with ",
             node->input(0)->debugName());
         node->output()->replaceAllUsesWith(node->input(0));
+        changed = true;
       }
     } else if (
         node->matches(
@@ -378,16 +403,19 @@ struct PeepholeOptimizeImpl {
             " (x * 1 == x / 1 == x) is replaced with ",
             node->input(0)->debugName());
         node->output()->replaceAllUsesWith(node->input(0));
+        changed = true;
       }
     }
+    return changed;
   }
 
  private:
   std::shared_ptr<Graph> graph_;
   bool shape_peepholes_;
 };
 
-void FuseAddMM(Block* block) {
+bool FuseAddMM(Block* block) {
+  bool changed = false;
   for (Node* node : block->nodes()) {
     // XXX: remember that if you want to simplify an expression by combining
     // multiple nodes into a different one, then you need to check that they
@@ -488,15 +516,17 @@ void FuseAddMM(Block* block) {
                 " into ",
                 addmm_value->debugName());
             node->output()->replaceAllUsesWith(addmm_value);
+            changed = true;
             continue;
           }
         }
       }
     }
     for (Block* b : node->blocks()) {
-      FuseAddMM(b);
+      changed |= FuseAddMM(b);
     }
   }
+  return changed;
 }
 
 // FuseAddMM is a separate pass from peephole optimize because it is currently
@@ -506,17 +536,21 @@ void FuseAddMM(Block* block) {
 // since after ONNX translation we would see redundant Gemm ops with sub-optimal
 // inputs. This flag is exposed so that ONNX export can pass `true` to get the
 // fused behavior, but normal JIT peephole optimization is left alone.
-void FuseAddMM(const std::shared_ptr<Graph>& graph) {
-  FuseAddMM(graph->block());
+bool FuseAddMM(const std::shared_ptr<Graph>& graph) {
+  return FuseAddMM(graph->block());
 }
 
-void PeepholeOptimize(
+bool PeepholeOptimize(
     const std::shared_ptr<Graph>& graph,
     bool addmm_fusion_enabled) {
   PeepholeOptimizeImpl peephole(graph, addmm_fusion_enabled);
+  bool changed = peephole.run();
   GRAPH_DUMP("After PeepholeOptimize: ", graph);
   // Eliminate dead code created by any peephole passes we've just done
-  EliminateDeadCode(graph->block());
+  if (changed) {
+    EliminateDeadCode(graph->block());
+  }
+  return changed;
 }
 
 } // namespace jit
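
Note: with the passes now reporting whether they mutated the graph, the "fixed point strategy" TODO above has a straightforward caller-side answer. A minimal sketch under the post-patch signature — `peepholeToFixedPoint` is a hypothetical helper, and termination additionally assumes no pair of rewrites keeps undoing each other:

    #include <memory>
    #include <torch/csrc/jit/ir/ir.h>
    #include <torch/csrc/jit/passes/peephole.h>

    // Re-run the pass until a full sweep reports no change.
    void peepholeToFixedPoint(const std::shared_ptr<torch::jit::Graph>& graph) {
      while (torch::jit::PeepholeOptimize(graph, /*addmm_fusion_enabled=*/false)) {
      }
    }
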