Skip to content

Commit 39b4701

Browse files
Hao Lufacebook-github-bot
Hao Lu
authored andcommitted
[caffe2][redo] Reimplement RemoveOpsByType with SSA (pytorch#41606)
Summary: Pull Request resolved: pytorch#41606 The previous diff (D22220798 (pytorch@59294fb) and D22220797) was recently reverted (D22492356 (pytorch@28291d3), D22492355) because of a bug associated with the op AsyncIf. The AsyncIf op has net_defs as args and the SSA rewriting didn't take that into account. It has a special path for the op If, but not for AsyncIf. Several changes I made to fix the bug: 1) Add op AsyncIf to the special path for If op in SSA rewriting 2) clear inputs/outputs of the netdefs that are args in If/AsyncIf ops because they're no longer valid 3) revert renamed inputs/outputs in the arg netdefs that are in the external_outputs in the parent netdef 2) and 3) are existing bugs in the `SsaRewrite` function that were just never exposed before. The algorithm for `RemoveOpsByType` is the same as in my previous diff D22220798 (pytorch@59294fb). The only new changes in this diff are in `onnx::SsaRewrite` and a few newly added unit tests. (Note: this ignores all push blocking failures!) Reviewed By: yinghai Differential Revision: D22588652 fbshipit-source-id: ebb68ecd1662ea2bae14d4be8f61a75cd8b7e3e6
1 parent 349c405 commit 39b4701

File tree

6 files changed

+307
-303
lines changed

6 files changed

+307
-303
lines changed

caffe2/onnx/onnx_exporter.cc

+76-38
Original file line numberDiff line numberDiff line change
@@ -103,31 +103,6 @@ NodeProto AddShapeNode(const std::string& input, const std::string& output) {
103103
return shape_node;
104104
}
105105

106-
} // namespace
107-
108-
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
109-
caffe2::TensorProto::DataType t) {
110-
#define CAFFE2_TO_ONNX_TYPE(x) \
111-
case (caffe2::TensorProto::x): \
112-
return ::ONNX_NAMESPACE::TensorProto::x
113-
switch (t) {
114-
CAFFE2_TO_ONNX_TYPE(FLOAT);
115-
CAFFE2_TO_ONNX_TYPE(BOOL);
116-
CAFFE2_TO_ONNX_TYPE(INT8);
117-
CAFFE2_TO_ONNX_TYPE(UINT8);
118-
CAFFE2_TO_ONNX_TYPE(UINT16);
119-
CAFFE2_TO_ONNX_TYPE(INT16);
120-
CAFFE2_TO_ONNX_TYPE(INT32);
121-
CAFFE2_TO_ONNX_TYPE(INT64);
122-
CAFFE2_TO_ONNX_TYPE(FLOAT16);
123-
default:
124-
LOG(WARNING) << "Unsupported Caffe2 tensor type: " << t
125-
<< ", fallback to FLOAT";
126-
return ::ONNX_NAMESPACE::TensorProto::FLOAT;
127-
}
128-
#undef CAFFE2_TO_ONNX_TYPE
129-
}
130-
131106
void collectExternalsFromIfOpSubnet(
132107
const NetDef* net,
133108
std::vector<std::string>* input,
@@ -158,6 +133,9 @@ void rewriteSubnet(
158133
Argument* arg,
159134
std::map<std::string, std::string> oldname_to_newname) {
160135
NetDef* net = arg->mutable_n();
136+
// clear external inputs and outputs since they're no longer valid
137+
net->mutable_external_input()->Clear();
138+
net->mutable_external_output()->Clear();
161139
for (auto& op : *(net->mutable_op())) {
162140
for (auto& input : *(op.mutable_input())) {
163141
if (oldname_to_newname.find(input) != oldname_to_newname.end()) {
@@ -245,6 +223,72 @@ void ssaRewriteForIfOp(
245223
}
246224
}
247225

226+
void revertRenamedExternalOutput(
227+
OperatorDef* op,
228+
const std::unordered_map<std::string, std::string>&
229+
renamed_external_outputs) {
230+
for (auto& input : *(op->mutable_input())) {
231+
const auto it = renamed_external_outputs.find(input);
232+
if (it != renamed_external_outputs.end()) {
233+
input = it->second;
234+
}
235+
}
236+
for (auto& output : *(op->mutable_output())) {
237+
const auto it = renamed_external_outputs.find(output);
238+
if (it != renamed_external_outputs.end()) {
239+
output = it->second;
240+
}
241+
}
242+
}
243+
244+
void revertRenamedExternalOutputForIfOp(
245+
OperatorDef* if_op,
246+
const std::unordered_map<std::string, std::string>&
247+
renamed_external_outputs) {
248+
ArgumentHelper helper(*if_op);
249+
Argument *then_arg = nullptr, *else_arg = nullptr;
250+
251+
if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
252+
then_arg = getArgumentFromName(if_op, "then_net");
253+
NetDef* net = then_arg->mutable_n();
254+
for (auto& op : *(net->mutable_op())) {
255+
revertRenamedExternalOutput(&op, renamed_external_outputs);
256+
}
257+
}
258+
if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
259+
else_arg = getArgumentFromName(if_op, "else_net");
260+
NetDef* net = else_arg->mutable_n();
261+
for (auto& op : *(net->mutable_op())) {
262+
revertRenamedExternalOutput(&op, renamed_external_outputs);
263+
}
264+
}
265+
}
266+
267+
} // namespace
268+
269+
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
270+
caffe2::TensorProto::DataType t) {
271+
#define CAFFE2_TO_ONNX_TYPE(x) \
272+
case (caffe2::TensorProto::x): \
273+
return ::ONNX_NAMESPACE::TensorProto::x
274+
switch (t) {
275+
CAFFE2_TO_ONNX_TYPE(FLOAT);
276+
CAFFE2_TO_ONNX_TYPE(BOOL);
277+
CAFFE2_TO_ONNX_TYPE(INT8);
278+
CAFFE2_TO_ONNX_TYPE(UINT8);
279+
CAFFE2_TO_ONNX_TYPE(UINT16);
280+
CAFFE2_TO_ONNX_TYPE(INT16);
281+
CAFFE2_TO_ONNX_TYPE(INT32);
282+
CAFFE2_TO_ONNX_TYPE(INT64);
283+
CAFFE2_TO_ONNX_TYPE(FLOAT16);
284+
default:
285+
LOG(WARNING) << "Unsupported Caffe2 tensor type: " << t
286+
<< ", fallback to FLOAT";
287+
return ::ONNX_NAMESPACE::TensorProto::FLOAT;
288+
}
289+
#undef CAFFE2_TO_ONNX_TYPE
290+
}
291+
248292
std::unordered_map<std::string, std::string> SsaRewrite(
249293
caffe2::NetDef* init_net,
250294
caffe2::NetDef* pred_net) {
@@ -288,13 +332,13 @@ std::unordered_map<std::string, std::string> SsaRewrite(
288332
}
289333
}
290334
// Special SSA Rewrite for subnet of If Operator
291-
if (op.type() == "If") {
335+
if (op.type() == "If" || op.type() == "AsyncIf") {
292336
ssaRewriteForIfOp(&op, &blob_versions, &is_initialized_tensor);
293337
}
294338
for (auto& output : *op.mutable_output()) {
295339
auto it = blob_versions.find(output);
296340
if (it != blob_versions.end()) {
297-
if (op.type() != "If") {
341+
if (op.type() != "If" && op.type() != "AsyncIf") {
298342
if (is_initialized_tensor.count(output) == 0) {
299343
it->second += 1;
300344
} else {
@@ -338,17 +382,11 @@ std::unordered_map<std::string, std::string> SsaRewrite(
338382
// Use the mapping to find if the input or output of an op was a renamed
339383
// external output. If so replace it with its original name.
340384
for (auto& op : *pred_net->mutable_op()) {
341-
for (auto& input : *op.mutable_input()) {
342-
const auto it = renamed_external_outputs.find(input);
343-
if (it != renamed_external_outputs.end()) {
344-
input = it->second;
345-
}
346-
}
347-
for (auto& output : *op.mutable_output()) {
348-
const auto it = renamed_external_outputs.find(output);
349-
if (it != renamed_external_outputs.end()) {
350-
output = it->second;
351-
}
385+
// If/AsyncIf needs special handling
386+
if (op.type() == "If" || op.type() == "AsyncIf") {
387+
revertRenamedExternalOutputForIfOp(&op, renamed_external_outputs);
388+
} else {
389+
revertRenamedExternalOutput(&op, renamed_external_outputs);
352390
}
353391
}
354392
}

caffe2/opt/tvm_transformer.cc

+4-9
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,15 @@ void TvmTransformer::transform(
156156
std::unordered_set<std::string> weights(
157157
weight_names.begin(), weight_names.end());
158158

159-
// SSA Rewrite the net
160-
auto shape_hints_mapped =
161-
ssaRewriteAndMapNames(ws, pred_net, input_shape_hints);
162-
163-
// Populate shape info
164-
Workspace mapped_ws(ws, input_mapping_);
159+
// input_shape_hints should only contain shapes of inputs and not activations
165160
ShapeInfoMap shape_hints;
166161
if (!opts_.profiling_based_jit) {
167-
shape_hints = inferShapes(
168-
&mapped_ws, pred_net, shape_hints_mapped, opts_.bound_shape_spec);
162+
shape_hints =
163+
inferShapes(ws, pred_net, input_shape_hints, opts_.bound_shape_spec);
169164
}
170165

171166
if (opts_.debug) {
172-
dumpNet(*pred_net, shape_hints, "debug_ssa_net.pbtxt");
167+
dumpNet(*pred_net, shape_hints, "debug_net.pbtxt");
173168
}
174169

175170
// We are ready to transform the net

caffe2/predictor/InferenceGraph.h

+2
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,7 @@ struct InferenceGraph {
2020
std::vector<std::string> input_names;
2121
std::vector<std::string> output_names;
2222
std::vector<std::string> parameter_names;
23+
24+
bool predictor_net_ssa_rewritten{false};
2325
};
2426
} // namespace caffe2

0 commit comments

Comments
 (0)