forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrestore_mutation.cpp
85 lines (69 loc) · 2.67 KB
/
restore_mutation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
namespace torch {
namespace jit {
FunctionalToInplaceRewriter::FunctionalToInplaceRewriter(
std::shared_ptr<Graph> graph)
: aliasDb_(nullptr), graph_(std::move(graph)) {}
// Returns true when `node` may be rewritten into its in-place ("<op>_")
// counterpart. All of the following must hold:
//   * the node's kind is listed in activation_type_promotion_mapping;
//   * the "<kind>_" symbol is valid;
//   * when the table flags the op for a dtype check, the first input and the
//     first output both have a known scalar type and the two are equal;
//   * the first input has no side-effecting producer / alias according to
//     MutationRemover::hasSideEffectOrAlias;
//   * the first input is used exactly once.
bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) {
  const auto kind = node->kind();

  // Ops that are not in the activation table are never candidates.
  if (activation_type_promotion_mapping.find(kind) ==
      activation_type_promotion_mapping.end()) {
    return false;
  }

  // The in-place variant is named by appending '_' to the functional name.
  Symbol inplace_op =
      Symbol::fromQualString(std::string(kind.toQualString()) + "_");
  if (!inplace_op) {
    return false;
  }

  // The mapped flag says whether type promotion is possible for this op, in
  // which case the dtypes have to be compared below.
  const bool check_dtype = activation_type_promotion_mapping.at(kind);

  Value* input = node->inputs().at(0);
  Value* output = node->outputs().at(0);

  // Both values are required to be tensors regardless of check_dtype, so the
  // expect<TensorType>() calls stay unconditional.
  auto inputDtype = input->type()->expect<TensorType>()->scalarType();
  auto outputDtype = output->type()->expect<TensorType>()->scalarType();

  // Activation ops are element-wise, so shapes need no checking in general.
  // Where type promotion could happen, however, input and output must share
  // a dtype. For now this check always fails until type inference is ready.
  if (check_dtype) {
    const bool same_dtype = inputDtype && outputDtype &&
        inputDtype.value() == outputDtype.value();
    if (!same_dtype) {
      return false;
    }
  }

  // Bail out if the input's defining node has side effects or the input is
  // aliased.
  if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) {
    return false;
  }

  // Skip the conversion when the input has more than one use.
  // TODO: Use liveness analysis to catch more general scenario
  const bool single_use = input->uses().size() == 1;
  return single_use;
}
// Walks `block` (recursing into every nested sub-block first) and replaces
// each node accepted by CanBeInplace with its "<name>_" in-place variant.
// Returns true if this block or any nested block was modified.
bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) {
  bool changed = false;

  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    Node* node = *it;
    // Step past `node` right away: it may be destroyed further down.
    ++it;

    for (Block* nested : node->blocks()) {
      changed |= FunctionalToInplace(nested);
    }

    if (CanBeInplace(node)) {
      changed = true;

      // Swap the functional op for a node carrying the in-place symbol.
      const Symbol inplace_symbol =
          Symbol::fromQualString(node->schema().name() + "_");
      Node* inplace_node = node->replaceWithNewSymbol(inplace_symbol);

      // Redirect every use of the in-place node's output to the node's first
      // input.
      Value* inplace_output = inplace_node->output();
      inplace_output->replaceAllUsesWith(node->inputs().at(0));

      // Tell the alias database that the new output stands in for the old
      // one before the old node goes away.
      getOrCreateAliasDb()->replaceWithNewValue(
          node->output(), inplace_output);

      node->destroy();
    }
  }

  return changed;
}
bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) {
FunctionalToInplaceRewriter rewriter(graph);
return rewriter.FunctionalToInplace(graph->block());
}
} // namespace jit
} // namespace torch