
Commit 908a74e

Peng Wu authored and facebook-github-bot committed
[Refactoring] make transformations return whether graph is modified (pytorch#54777)
Summary: Pull Request resolved: pytorch#54777

Updated RemoveListMutation, PeepholeOptimizeListIdioms, PeepholeOptimizeAliasSensitive, LoopUnrolling, and PeepholeOptimize to return whether the graph is modified by the transformation.

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D27412105

fbshipit-source-id: 0c1520bc34f6bd59acd83d98bed58897376eac41
1 parent a37fbf9 commit 908a74e

10 files changed: +141 −62 lines
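
The practical point of this refactoring is that callers can now tell whether a pass did anything, which lets a driver iterate passes to a fixed point and skip follow-up cleanup when nothing changed. A minimal sketch of such a driver, assuming the post-PR signatures shown in the diffs below (the runToFixedPoint helper itself is hypothetical, not part of this commit):

#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/peephole.h>

namespace torch {
namespace jit {

// Hypothetical driver: re-run the passes until a full sweep reports no
// modification. Each pass returns true iff it changed the graph, so the
// loop stops as soon as a sweep is a no-op.
void runToFixedPoint(std::shared_ptr<Graph>& graph) {
  bool changed = true;
  while (changed) {
    changed = false;
    changed |= PeepholeOptimize(graph, /*addmm_fusion_enabled=*/false);
    changed |= UnrollLoops(graph);
  }
}

} // namespace jit
} // namespace torch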

torch/csrc/jit/passes/loop_unrolling.cpp (+20 −12)

@@ -162,11 +162,11 @@ void replaceLoopCounter(Node* loop) {
   body->insertOutput(1, result);
 }
 
-void unroll(Node* loop) {
+bool unroll(Node* loop) {
   Graph* graph = loop->owningGraph();
   Block* body = loop->blocks().at(0);
   if (!isSmallBlock(body))
-    return;
+    return false;
 
   // We will be using a "mutable" counter outside of the loop instead of the
   // default one, because this will allow us to share it between the unrolled
@@ -184,7 +184,7 @@ void unroll(Node* loop) {
     repeatBody(body, *const_len, dest);
     loop->eraseBlock(0);
     inlineBody(loop);
-    return;
+    return true;
   }
 
   WithInsertPoint insert_point_guard{loop};
@@ -212,21 +212,25 @@ void unroll(Node* loop) {
       aten::sub,
       {iter_count,
        graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
+
+  return true;
 }
 
-void UnrollLoops(Block* block) {
+bool UnrollLoops(Block* block) {
+  bool changed = false;
   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
     // XXX: unroll might destroy the current node, so we need to pre-increment
     // the iterator
     Node* node = *it;
     ++it;
     for (Block* subblock : node->blocks()) {
-      UnrollLoops(subblock);
+      changed |= UnrollLoops(subblock);
    }
    if (isForLoop(node)) {
-      unroll(node);
+      changed |= unroll(node);
    }
  }
+  return changed;
 }
 
 } // anonymous namespace
@@ -244,11 +248,12 @@ static void addCondAsOutput(Node* loop) {
   cond_output->copyMetadata(loop_view.nextCond());
 }
 
-void LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
+bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("Before LoopsPeeler", graph);
   collectLoops(graph->block());
   peelLoops();
   GRAPH_DUMP("After LoopsPeeler", graph);
+  return true;
 }
 
 void LoopsPeeler::collectLoop(Node* n) {
@@ -288,7 +293,7 @@ void LoopsPeeler::peelLoops() {
   }
 }
 
-void PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
+bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
   auto peel_predicate = [](Node* n) {
     for (auto i : n->inputs()) {
       if (i->type()->isSubtypeOf(TensorType::get())) {
@@ -300,7 +305,7 @@ void PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
   };
 
   LoopsPeeler lp(peel_predicate);
-  lp.run(graph);
+  return lp.run(graph);
 }
 
 Node* PeelLoop(Node* n, size_t times) {
@@ -360,9 +365,12 @@ Node* PeelLoop(Node* n, size_t times) {
   return peeled_copy;
 }
 
-void UnrollLoops(std::shared_ptr<Graph>& graph) {
-  UnrollLoops(graph->block());
-  EliminateDeadCode(graph);
+bool UnrollLoops(std::shared_ptr<Graph>& graph) {
+  bool changed = UnrollLoops(graph->block());
+  if (changed) {
+    EliminateDeadCode(graph);
+  }
+  return changed;
 }
 
 } // namespace jit
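
Note one conservative choice above: LoopsPeeler::run unconditionally returns true, so PeelProfilingLoops reports a modification even when no loop matched the predicate. A short usage sketch of the predicate-driven API (the predicate is illustrative, modeled on the one inside PeelProfilingLoops):

// Peel only loops that consume at least one Tensor input.
auto peel_tensor_loops = [](Node* n) {
  for (Value* i : n->inputs()) {
    if (i->type()->isSubtypeOf(TensorType::get())) {
      return true;
    }
  }
  return false;
};
LoopsPeeler peeler(peel_tensor_loops, /*num_iterations=*/1);
bool modified = peeler.run(graph); // currently always true, see above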

torch/csrc/jit/passes/loop_unrolling.h (+5 −3)

@@ -5,17 +5,19 @@
 namespace torch {
 namespace jit {
 
-TORCH_API void UnrollLoops(std::shared_ptr<Graph>& graph);
+// return true if graph is modified
+TORCH_API bool UnrollLoops(std::shared_ptr<Graph>& graph);
 
 TORCH_API Node* PeelLoop(Node* n, size_t times);
 
-TORCH_API void PeelProfilingLoops(const std::shared_ptr<Graph>& graph);
+// return true if graph is modified
+TORCH_API bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph);
 
 struct TORCH_API LoopsPeeler {
   LoopsPeeler(std::function<bool(Node* n)> callback, size_t num_iterations = 1)
       : callback_(std::move(callback)), num_iterations_(num_iterations) {}
 
-  void run(const std::shared_ptr<Graph>& graph);
+  bool run(const std::shared_ptr<Graph>& graph);
 
  private:
   void collectLoop(Node* n);
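
A detail worth noting in how these results are consumed: the .cpp changes accumulate recursive results with |= rather than ||. With ||, evaluation would short-circuit once changed is true and skip running the pass on the remaining sub-blocks; bitwise |= always evaluates its right-hand side. A minimal illustration (processBlock is a stand-in name):

bool changed = false;
for (Block* sub : node->blocks()) {
  changed |= processBlock(sub);              // always runs on every sub-block
  // changed = changed || processBlock(sub); // WRONG: skips once changed is true
}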

torch/csrc/jit/passes/peephole.cpp (+47 −13)

@@ -24,10 +24,13 @@ struct PeepholeOptimizeImpl {
   PeepholeOptimizeImpl(
       const std::shared_ptr<Graph>& graph,
       bool disable_shape_peepholes)
-      : graph_(graph), shape_peepholes_(!disable_shape_peepholes) {
-    run(graph->block());
-    PeepholeOptimizeListIdioms(graph);
-    PeepholeOptimizeAliasSensitive(graph);
+      : graph_(graph), shape_peepholes_(!disable_shape_peepholes) {}
+
+  bool run() {
+    bool changed = optimizeBlock(graph_->block());
+    changed |= PeepholeOptimizeListIdioms(graph_);
+    changed |= PeepholeOptimizeAliasSensitive(graph_);
+    return changed;
   }
 
   // The intent for this optimization pass is to catch all of the small, easy to
@@ -39,12 +42,13 @@ struct PeepholeOptimizeImpl {
   //
   // TODO: Decide what kind of fixed point strategy we will have
   //
-  void run(Block* block) {
+  bool optimizeBlock(Block* block) {
+    bool changed = false;
     for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
       auto* node = *it;
 
       for (Block* sub_block : node->blocks()) {
-        run(sub_block);
+        changed |= optimizeBlock(sub_block);
       }
 
       if (node->kind() != prim::Constant) {
@@ -55,6 +59,7 @@ struct PeepholeOptimizeImpl {
         for (Value* output : node->outputs()) {
           if (output->type()->cast<NoneType>()) {
             output->replaceAllUsesWith(graph_->insertConstant(IValue()));
+            changed = true;
           }
         }
       }
@@ -71,6 +76,7 @@ struct PeepholeOptimizeImpl {
               " (x._grad_sum_to_size(x, None) == x) is replaced with ",
               node->input(0)->debugName());
           node->output()->replaceAllUsesWith(node->input(0));
+          changed = true;
         } else {
           auto uses = node->output()->uses();
           for (Use u : uses) {
@@ -82,6 +88,7 @@ struct PeepholeOptimizeImpl {
                 " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
                 node->inputs().at(0)->debugName());
             u.user->replaceInput(0, node->inputs().at(0));
+            changed = true;
           }
         }
       }
@@ -102,6 +109,7 @@ struct PeepholeOptimizeImpl {
               " (x.expand(x.size()) == x) is replaced with ",
               node->namedInput(attr::self)->debugName());
           node->output()->replaceAllUsesWith(node->namedInput(attr::self));
+          changed = true;
         }
       }
     } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
@@ -113,6 +121,7 @@ struct PeepholeOptimizeImpl {
               " (x.t().t() == x) is replaced with ",
               input_node->input()->debugName());
           node->output()->replaceAllUsesWith(input_node->input());
+          changed = true;
         }
       } else if (
           node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") &&
@@ -127,6 +136,7 @@ struct PeepholeOptimizeImpl {
               " (x.type_as(y) == x) is replaced with ",
               node->input(0)->debugName());
           node->output()->replaceAllUsesWith(node->input(0));
+          changed = true;
         }
       } else if (
           node->kind() == aten::Float || node->kind() == aten::Int ||
@@ -140,6 +150,7 @@ struct PeepholeOptimizeImpl {
              " (x.NumToTensor().TensorToNum() == x.NumToTensor()) is replaced with ",
              node->input()->debugName());
          node->output()->replaceAllUsesWith(input_node->input());
+         changed = true;
        }
      } else if (
          node->matches("aten::size(Tensor self) -> int[]") &&
@@ -154,6 +165,7 @@ struct PeepholeOptimizeImpl {
            IValue ival(sizes);
            auto const_sizes_val = node->owningGraph()->insertConstant(ival);
            node->output()->replaceAllUsesWith(const_sizes_val);
+           changed = true;
          }
        }
      } else if (
@@ -174,6 +186,7 @@ struct PeepholeOptimizeImpl {
            IValue ival(*ptt->sizes()[norm_index]);
            auto const_sizes_val = node->owningGraph()->insertConstant(ival);
            node->output()->replaceAllUsesWith(const_sizes_val);
+           changed = true;
          }
        }
      }
@@ -187,6 +200,7 @@ struct PeepholeOptimizeImpl {
          IValue ival(at::isFloatingType(dtype));
          auto new_constant = node->owningGraph()->insertConstant(ival);
          node->output()->replaceAllUsesWith(new_constant);
+         changed = true;
        }
      } else if (
          node->matches("aten::is_complex(Tensor self) -> bool") &&
@@ -220,6 +234,7 @@ struct PeepholeOptimizeImpl {
              " (True or False) with ",
              n.cond()->debugName());
          n.outputs().at(i)->replaceAllUsesWith(n.cond());
+         changed = true;
        }
      }
    } else if (
@@ -240,6 +255,7 @@ struct PeepholeOptimizeImpl {
        GRAPH_UPDATE(
            "Folding ", getHeader(node), " to ", output->debugName());
        node->output()->replaceAllUsesWith(output);
+       changed = true;
      }
    }
  } else if (
@@ -255,6 +271,7 @@ struct PeepholeOptimizeImpl {
            node->input(),
            " can't be optional");
        node->output()->replaceAllUsesWith(node->input());
+       changed = true;
      }
    } else if (node->kind() == prim::unchecked_cast) {
      // unchecked_cast is not generated for tensor properties, so we are not
@@ -267,6 +284,7 @@ struct PeepholeOptimizeImpl {
            getHeader(node),
            " as input type subtypes output type");
        node->output()->replaceAllUsesWith(node->input());
+       changed = true;
      }
    } else if (
        node->matches("prim::dtype(Tensor a) -> int") && shape_peepholes_) {
@@ -281,6 +299,7 @@ struct PeepholeOptimizeImpl {
            " with a type constant ",
            output->debugName());
        node->output()->replaceAllUsesWith(output);
+       changed = true;
      }
    } else if (
        node->matches("prim::device(Tensor a) -> Device") &&
@@ -295,6 +314,7 @@ struct PeepholeOptimizeImpl {
            " with a device constant ",
            output->debugName());
        node->output()->replaceAllUsesWith(output);
+       changed = true;
      }
    } else if (
        node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
@@ -309,6 +329,7 @@ struct PeepholeOptimizeImpl {
            " with a \"dim\" constant ",
            output->debugName());
        node->output()->replaceAllUsesWith(output);
+       changed = true;
      }
    } else if (
        node->matches("prim::is_cuda(Tensor a) -> bool") &&
@@ -324,6 +345,7 @@ struct PeepholeOptimizeImpl {
            " with a is_cuda constant ",
            output->debugName());
        node->output()->replaceAllUsesWith(output);
+       changed = true;
      }
    }
 
@@ -333,6 +355,7 @@ struct PeepholeOptimizeImpl {
      // the limited speedup of these optimizations
      // runAliasingSensitivePeepholeTransformations(node);
    }
+    return changed;
  }
 
  // if either the inputs or outputs of an op alias graph's inputs or
@@ -345,10 +368,11 @@ struct PeepholeOptimizeImpl {
  //   s += x
  //   return s
  //
-  void runAliasingSensitivePeepholeTransformations(Node* node) {
+  bool runAliasingSensitivePeepholeTransformations(Node* node) {
    // this code is not currently enabled, see [aliasing sensitive
    // optimizations]
    TORCH_INTERNAL_ASSERT(false);
+    bool changed = false;
    if (node->matches(
            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
            /*const_inputs=*/{attr::alpha, attr::other}) ||
@@ -363,6 +387,7 @@ struct PeepholeOptimizeImpl {
            " (x + 0 == x - 0 == x) is replaced with ",
            node->input(0)->debugName());
        node->output()->replaceAllUsesWith(node->input(0));
+       changed = true;
      }
    } else if (
        node->matches(
@@ -378,16 +403,19 @@ struct PeepholeOptimizeImpl {
            " (x * 1 == x / 1 == x) is replaced with ",
            node->input(0)->debugName());
        node->output()->replaceAllUsesWith(node->input(0));
+       changed = true;
      }
    }
+    return changed;
  }
 
  private:
  std::shared_ptr<Graph> graph_;
  bool shape_peepholes_;
 };
 
-void FuseAddMM(Block* block) {
+bool FuseAddMM(Block* block) {
+  bool changed = false;
   for (Node* node : block->nodes()) {
     // XXX: remember that if you want to simplify an expression by combining
     // multiple nodes into a different one, then you need to check that they
@@ -488,15 +516,17 @@ void FuseAddMM(Block* block) {
                " into ",
                addmm_value->debugName());
            node->output()->replaceAllUsesWith(addmm_value);
+           changed = true;
            continue;
          }
        }
      }
    }
    for (Block* b : node->blocks()) {
-      FuseAddMM(b);
+      changed |= FuseAddMM(b);
    }
  }
+  return changed;
 }
 
 // FuseAddMM is a separate pass from peephole optimize because it is currently
@@ -506,17 +536,21 @@ void FuseAddMM(Block* block) {
 // since after ONNX translation we would see redundant Gemm ops with sub-optimal
 // inputs. This flag is exposed so that ONNX export can pass `true` to get the
 // fused behavior, but normal JIT peephole optimization is left alone.
-void FuseAddMM(const std::shared_ptr<Graph>& graph) {
-  FuseAddMM(graph->block());
+bool FuseAddMM(const std::shared_ptr<Graph>& graph) {
+  return FuseAddMM(graph->block());
 }
 
-void PeepholeOptimize(
+bool PeepholeOptimize(
     const std::shared_ptr<Graph>& graph,
     bool addmm_fusion_enabled) {
   PeepholeOptimizeImpl peephole(graph, addmm_fusion_enabled);
+  bool changed = peephole.run();
   GRAPH_DUMP("After PeepholeOptimize: ", graph);
   // Eliminate dead code created by any peephole passes we've just done
-  EliminateDeadCode(graph->block());
+  if (changed) {
+    EliminateDeadCode(graph->block());
+  }
+  return changed;
 }
 
 } // namespace jit
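
One immediate payoff of the new return values is testability: a caller can assert that a pass converges, i.e. that a second invocation is a no-op. A minimal sketch, assuming the IR parser in torch/csrc/jit/ir/irparser.h and the x.t().t() == x peephole shown above (the assertion usage here is illustrative, not taken from this PR's tests):

auto graph = std::make_shared<Graph>();
parseIR(R"IR(
  graph(%x : Tensor):
    %t1 : Tensor = aten::t(%x)
    %t2 : Tensor = aten::t(%t1)
    return (%t2)
)IR", graph.get());

// The first run folds x.t().t() -> x and reports a modification...
TORCH_INTERNAL_ASSERT(PeepholeOptimize(graph));
// ...and after dead-code elimination nothing is left to do, so a
// second run reports no modification.
TORCH_INTERNAL_ASSERT(!PeepholeOptimize(graph));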
