#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <vector>

namespace torch {
namespace jit {
namespace {
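
// Helpers for reading and writing the requires_grad bit stored on a value's
// TensorType. setRequiresGrad is a no-op for values that are not tensors.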
bool getRequiresGrad(Value* value) {
return value->requires_grad();
}
void setRequiresGrad(Value* value, bool req_value) {
if (auto type = value->type()->cast<TensorType>()) {
value->setType(type->withRequiresGrad(req_value));
}
}
void setRequiresGrad(
at::ArrayRef<Value*> outputs,
const std::vector<bool>& values) {
AT_ASSERT(outputs.size() == values.size());
for (const auto i : c10::irange(values.size())) {
setRequiresGrad(outputs[i], values[i]);
}
}
void setRequiresGrad(Node* node, const std::vector<bool>& values) {
setRequiresGrad(node->outputs(), values);
}
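
// Element-wise OR of two equally sized bool vectors. Used to merge
// requires_grad flags coming from the two branches of an If node and from
// successive Loop iterations.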
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
AT_ASSERT(a.size() == b.size());
for (const auto i : c10::irange(a.size())) {
a[i] = a[i] || b[i];
}
return a;
}
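
// Propagates requires_grad through a node that has no sub-blocks. Comparison
// ops and detach never produce outputs that require grad; type_as follows its
// first input; aten::tensor is handled specially below; for any other node an
// output requires grad if some input does and the output has a differentiable
// scalar type.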
void PropagateRequiresGradSimpleNode(Node* node) {
static const OperatorSet comparison_ops = {
"aten::lt(Tensor self, Tensor other) -> Tensor",
"aten::le(Tensor self, Tensor other) -> Tensor",
"aten::gt(Tensor self, Tensor other) -> Tensor",
"aten::ge(Tensor self, Tensor other) -> Tensor",
"aten::eq(Tensor self, Tensor other) -> Tensor",
"aten::ne(Tensor self, Tensor other) -> Tensor",
"aten::lt(Tensor self, Scalar other) -> Tensor",
"aten::le(Tensor self, Scalar other) -> Tensor",
"aten::gt(Tensor self, Scalar other) -> Tensor",
"aten::ge(Tensor self, Scalar other) -> Tensor",
"aten::eq(Tensor self, Scalar other) -> Tensor",
"aten::ne(Tensor self, Scalar other) -> Tensor",
};
// NOLINTNEXTLINE(bugprone-branch-clone)
if (node->isMemberOf(comparison_ops)) {
return setRequiresGrad(node->output(), false);
} else if (node->matches(
"aten::type_as(Tensor self, Tensor other) -> Tensor")) {
return setRequiresGrad(node->output(), node->input(0)->requires_grad());
} else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
return setRequiresGrad(node->output(), false);
} else if (node->kind() == aten::tensor) {
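    // aten::tensor: honor an explicit requires_grad constant argument when
    // present; otherwise fall back to whether the output's scalar type is
    // differentiable.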
if (auto grad_index =
node->schema().argumentIndexWithName("requires_grad")) {
if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
return setRequiresGrad(node->output(), *const_arg);
}
}
if (auto type = node->output()->type()->cast<TensorType>()) {
if (type->scalarType()) {
setRequiresGrad(
node->output(),
autograd::isDifferentiableType(*type->scalarType()));
}
}
return;
}
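  // Generic rule: an output requires grad iff at least one input requires
  // grad and the output's scalar type is differentiable.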
auto inputs = node->inputs();
auto outputs = node->outputs();
bool should_require =
std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
for (Value* output : outputs) {
if (auto type = output->type()->cast<TensorType>()) {
if (type->scalarType()) {
setRequiresGrad(
output,
should_require &&
autograd::isDifferentiableType(*type->scalarType()));
}
}
}
}
void PropagateRequiresGrad(Block* block);
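
// Dispatches on node kind: prim::If merges the two branches with an
// element-wise OR, prim::Loop iterates its body until the flags reach a
// fixed point, and all other nodes are handled by
// PropagateRequiresGradSimpleNode.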
void PropagateRequiresGrad(Node* node) {
if (node->kind() == prim::If) {
auto blocks = node->blocks();
auto true_block = blocks.at(0);
auto false_block = blocks.at(1);
PropagateRequiresGrad(true_block);
PropagateRequiresGrad(false_block);
auto outputs_require = bitwiseOr(
fmap(true_block->outputs(), getRequiresGrad),
fmap(false_block->outputs(), getRequiresGrad));
setRequiresGrad(node, outputs_require);
} else if (node->kind() == prim::Loop) {
auto body = node->blocks().at(0);
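    // Loop node inputs are (max_trip_count, initial_condition, carried...),
    // the body's parameters are (iteration_count, carried...), and the body's
    // outputs are (continue_condition, carried...), hence the slice(2) and
    // slice(1) offsets below.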
std::vector<bool> loop_inputs_require =
fmap(node->inputs().slice(2), getRequiresGrad);
std::vector<bool> body_inputs_require = loop_inputs_require;
std::vector<bool> body_outputs_require(node->outputs().size(), false);
std::vector<bool> new_body_inputs_require = body_inputs_require;
std::vector<bool> new_body_outputs_require = body_outputs_require;
// continue iterating until the results have converged
do {
body_inputs_require = new_body_inputs_require;
body_outputs_require = new_body_outputs_require;
new_body_inputs_require =
bitwiseOr(body_inputs_require, body_outputs_require);
setRequiresGrad(
body->param_node()->outputs().slice(1), new_body_inputs_require);
PropagateRequiresGrad(body);
new_body_outputs_require =
fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
} while (new_body_inputs_require != body_inputs_require ||
new_body_outputs_require != body_outputs_require);
setRequiresGrad(node, bitwiseOr(body_outputs_require, loop_inputs_require));
} else {
PropagateRequiresGradSimpleNode(node);
}
}
void PropagateRequiresGrad(Block* block) {
for (Node* node : block->nodes()) {
PropagateRequiresGrad(node);
}
}
} // anonymous namespace
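
// Public entry point: propagates requires_grad information from the graph's
// inputs through every node and sub-block, recording the result on each
// value's TensorType.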
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
PropagateRequiresGrad(graph->block());
}
} // namespace jit
} // namespace torch