16
16
namespace torch {
17
17
namespace jit {
18
18
19
- c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant (const Node* n) {
19
+ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant (
20
+ const Node* n,
21
+ bool ignore_custom_classes) {
20
22
Stack stack;
21
23
for (auto input : n->inputs ()) {
22
24
if (auto ival = toIValue (input)) {
@@ -25,6 +27,7 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
25
27
return c10::nullopt;
26
28
}
27
29
}
30
+
28
31
switch (n->kind ()) {
29
32
case prim::ListUnpack: {
30
33
if (stack.back ().toList ().size () != n->outputs ().size ()) {
@@ -80,6 +83,12 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
80
83
return c10::nullopt;
81
84
}
82
85
}
86
+ // Weak form of const propagation
87
+ if (ignore_custom_classes) {
88
+ if (v.isCustomClass ()) {
89
+ return c10::nullopt;
90
+ }
91
+ }
83
92
}
84
93
return stack;
85
94
}
@@ -104,33 +113,40 @@ std::unordered_set<Symbol> skip_list = {
104
113
struct ConstantPropagator {
105
114
// Runs constant propagation with an aliasing db and checks if inputs or
106
115
// outputs might be mutated in the graph
107
- static ConstantPropagator WithAliasDb (std::shared_ptr<Graph> graph) {
108
- return ConstantPropagator (graph, true );
116
+ static ConstantPropagator WithAliasDb (
117
+ std::shared_ptr<Graph> graph,
118
+ bool ignore_custom_classes) {
119
+ return ConstantPropagator (std::move (graph), true , ignore_custom_classes);
109
120
}
110
121
111
122
// Runs constant propagation only on ops that clearly do not have aliased
112
123
// inputs or outputs without computing aliasing information
113
124
static ConstantPropagator NoAliasDb (std::shared_ptr<Graph> graph) {
114
- return ConstantPropagator (graph, false );
125
+ return ConstantPropagator (std::move ( graph), false , false );
115
126
}
116
127
117
128
void run () {
118
129
ConstantPropagation (graph_->block ());
119
130
}
120
131
121
132
private:
122
- ConstantPropagator (std::shared_ptr<Graph> graph, bool aliasing_types)
133
+ ConstantPropagator (
134
+ std::shared_ptr<Graph> graph,
135
+ bool aliasing_types,
136
+ bool ignore_custom_classes)
123
137
: graph_(std::move(graph)) {
124
138
if (aliasing_types) {
125
139
aliasDb_ = torch::make_unique<AliasDb>(graph_);
126
140
} else {
127
141
aliasDb_ = nullptr ;
128
142
}
143
+ ignore_custom_classes_ = ignore_custom_classes;
129
144
}
130
145
131
146
void propagateNode (Node* n) {
132
147
std::vector<IValue> outputs;
133
- if (auto outputs_opt = runNodeIfInputsAreConstant (n)) {
148
+ if (auto outputs_opt =
149
+ runNodeIfInputsAreConstant (n, ignore_custom_classes_)) {
134
150
outputs = std::move (outputs_opt.value ());
135
151
} else {
136
152
// The op failed to run, so we cannot continue constant-prop for it.
@@ -353,11 +369,15 @@ struct ConstantPropagator {
353
369
354
370
std::shared_ptr<Graph> graph_;
355
371
std::unique_ptr<AliasDb> aliasDb_;
372
+ bool ignore_custom_classes_;
356
373
};
357
374
} // anonymous namespace
358
375
359
- void ConstantPropagation (std::shared_ptr<Graph>& graph) {
360
- ConstantPropagator cp = ConstantPropagator::WithAliasDb (graph);
376
+ void ConstantPropagation (
377
+ std::shared_ptr<Graph>& graph,
378
+ bool ignore_custom_classes) {
379
+ ConstantPropagator cp =
380
+ ConstantPropagator::WithAliasDb (graph, ignore_custom_classes);
361
381
cp.run ();
362
382
EliminateDeadCode (graph);
363
383
GRAPH_DUMP (" After ConstantPropagation: " , graph);
0 commit comments