Skip to content

Commit b6c8f79

Browse files
author
Abhishek Kulkarni
committed
Remove "static inputs" for reduction ops
1 parent 940d415 commit b6c8f79

File tree

2 files changed

+21
-59
lines changed

2 files changed

+21
-59
lines changed

ngraph_bridge/ngraph_builder.cc

Lines changed: 21 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,45 +2079,8 @@ static Status TranslateNonMaxSuppressionV4Op(
20792079
return Status::OK();
20802080
}
20812081

2082-
static Status TranslateReduceOp(
2083-
const Node* op, const std::vector<const Tensor*>& static_input_map,
2084-
Builder::OpMap& ng_op_map,
2085-
std::function<ng::Output<ng::Node>(ng::Output<ng::Node>,
2086-
ng::Output<ng::Node>, const bool)>
2087-
create_ng_node) {
2088-
ng::Output<ng::Node> ng_input;
2089-
TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, ng_input));
2090-
bool tf_keep_dims;
2091-
if (GetNodeAttr(op->attrs(), "keep_dims", &tf_keep_dims) != Status::OK()) {
2092-
tf_keep_dims = false;
2093-
}
2094-
2095-
std::vector<int64> axes;
2096-
TF_RETURN_IF_ERROR(GetStaticInputVector(op, 1, static_input_map, &axes));
2097-
2098-
ng::Shape input_shape = ng_input.get_shape();
2099-
size_t input_rank = input_shape.size();
2100-
2101-
TF_RETURN_IF_ERROR(CheckAxisDimInRange(axes, input_rank));
2102-
2103-
std::vector<size_t> ng_reduction_axes_vect(axes.size());
2104-
std::transform(
2105-
axes.begin(), axes.end(), ng_reduction_axes_vect.begin(),
2106-
[input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
2107-
auto ng_reduction_axes = ConstructNgNode<opset::Constant>(
2108-
op->name(), ng::element::i64, ng::Shape{ng_reduction_axes_vect.size()},
2109-
ng_reduction_axes_vect);
2110-
2111-
ng::Output<ng::Node> ng_node =
2112-
create_ng_node(ng_input, ng_reduction_axes, tf_keep_dims);
2113-
Builder::SetTracingInfo(op->name(), ng_node);
2114-
2115-
SaveNgOp(ng_op_map, op->name(), ng_node);
2116-
return Status::OK();
2117-
}
2118-
21192082
template <typename T>
2120-
static Status TranslateDirectReduceOp(
2083+
static Status TranslateReduceOp(
21212084
const Node* op, const std::vector<const Tensor*>& static_input_map,
21222085
Builder::OpMap& ng_op_map) {
21232086
// ensure its either an arithmetic or a logical reduction
@@ -2127,13 +2090,19 @@ static Status TranslateDirectReduceOp(
21272090
"Expected node to be either a valid logical or arithmetic reduction "
21282091
"type");
21292092
}
2130-
return TranslateReduceOp(
2131-
op, static_input_map, ng_op_map,
2132-
[&op](ng::Output<ng::Node> ng_input,
2133-
ng::Output<ng::Node> ng_reduction_axes, const bool keep_dims) {
2134-
return ConstructNgNode<T>(op->name(), ng_input, ng_reduction_axes,
2135-
keep_dims);
2136-
});
2093+
2094+
shared_ptr<ng::Node> ng_input, ng_reduction_indices;
2095+
TF_RETURN_IF_ERROR(
2096+
GetInputNodes(ng_op_map, op, &ng_input, &ng_reduction_indices));
2097+
bool keep_dims;
2098+
if (GetNodeAttr(op->attrs(), "keep_dims", &keep_dims) != Status::OK()) {
2099+
keep_dims = false;
2100+
}
2101+
2102+
auto ng_node =
2103+
ConstructNgNode<T>(op->name(), ng_input, ng_reduction_indices, keep_dims);
2104+
SaveNgOp(ng_op_map, op->name(), ng_node);
2105+
return Status::OK();
21372106
}
21382107

21392108
static Status TranslateOneHotOp(
@@ -3002,8 +2971,8 @@ const static std::map<
30022971
{"Add", TranslateBinaryOp<opset::Add>},
30032972
{"AddN", TranslateAddNOp},
30042973
{"AddV2", TranslateBinaryOp<opset::Add>},
3005-
{"Any", TranslateDirectReduceOp<opset::ReduceLogicalOr>},
3006-
{"All", TranslateDirectReduceOp<opset::ReduceLogicalAnd>},
2974+
{"Any", TranslateReduceOp<opset::ReduceLogicalOr>},
2975+
{"All", TranslateReduceOp<opset::ReduceLogicalAnd>},
30072976
{"ArgMax", TranslateArgMaxOp},
30082977
{"ArgMin", TranslateArgMinOp},
30092978
{"Asin", TranslateUnaryOp<opset::Asin>},
@@ -3053,13 +3022,13 @@ const static std::map<
30533022
{"LogicalNot", TranslateUnaryOp<opset::LogicalNot>},
30543023
{"LogicalOr", TranslateBinaryOp<opset::LogicalOr>},
30553024
{"MatMul", TranslateMatMulOp},
3056-
{"Max", TranslateDirectReduceOp<opset::ReduceMax>},
3025+
{"Max", TranslateReduceOp<opset::ReduceMax>},
30573026
{"Maximum", TranslateBinaryOp<opset::Maximum>},
30583027
{"MaxPool", TranslateMaxPoolOp},
30593028
{"MaxPool3D", TranslateMaxPool3DOp},
30603029
{"NonMaxSuppressionV4", TranslateNonMaxSuppressionV4Op},
3061-
{"Mean", TranslateDirectReduceOp<opset::ReduceMean>},
3062-
{"Min", TranslateDirectReduceOp<opset::ReduceMin>},
3030+
{"Mean", TranslateReduceOp<opset::ReduceMean>},
3031+
{"Min", TranslateReduceOp<opset::ReduceMin>},
30633032
{"Minimum", TranslateBinaryOp<opset::Minimum>},
30643033
{"MirrorPad", TranslatePadOp},
30653034
{"Mul", TranslateBinaryOp<opset::Multiply>},
@@ -3077,7 +3046,7 @@ const static std::map<
30773046
{"Pow", TranslateBinaryOp<opset::Power>},
30783047
// PreventGradient is just Identity in dataflow terms, so reuse that.
30793048
{"PreventGradient", TranslateIdentityOp},
3080-
{"Prod", TranslateDirectReduceOp<opset::ReduceProd>},
3049+
{"Prod", TranslateReduceOp<opset::ReduceProd>},
30813050
{"Rank", TranslateRankOp},
30823051
{"RealDiv", TranslateBinaryOp<opset::Divide>},
30833052
{"Reciprocal", TranslateReciprocalOp},
@@ -3106,7 +3075,7 @@ const static std::map<
31063075
{"Squeeze", TranslateSqueezeOp},
31073076
{"StridedSlice", TranslateStridedSliceOp},
31083077
{"Sub", TranslateBinaryOp<opset::Subtract>},
3109-
{"Sum", TranslateDirectReduceOp<opset::ReduceSum>},
3078+
{"Sum", TranslateReduceOp<opset::ReduceSum>},
31103079
{"Tan", TranslateUnaryOp<opset::Tan>},
31113080
{"Tanh", TranslateUnaryOp<opset::Tanh>},
31123081
{"Tile", TranslateTileOp},

ngraph_bridge/ngraph_mark_for_clustering.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,32 +195,25 @@ const std::map<std::string, SetAttributesFunction>& GetAttributeSetters() {
195195

196196
if (!initialized) {
197197
// Set Additional Attributes (if any)
198-
set_attributes_map["Any"] = SetStaticInputs({1});
199-
set_attributes_map["All"] = SetStaticInputs({1});
200198
set_attributes_map["ArgMax"] = SetStaticInputs({1});
201199
set_attributes_map["ArgMin"] = SetStaticInputs({1});
202200
set_attributes_map["ConcatV2"] = SetStaticInputs({-1});
203201
set_attributes_map["Conv2DBackpropInput"] = SetStaticInputs({0});
204202
set_attributes_map["ExpandDims"] = SetStaticInputs({1});
205203
set_attributes_map["Fill"] = SetStaticInputs({0});
206204
set_attributes_map["GatherV2"] = SetStaticInputs({2});
207-
set_attributes_map["Max"] = SetStaticInputs({1});
208-
set_attributes_map["Mean"] = SetStaticInputs({1});
209-
set_attributes_map["Min"] = SetStaticInputs({1});
210205
set_attributes_map["MirrorPad"] = SetStaticInputs({1});
211206
set_attributes_map["NonMaxSuppressionV4"] = SetStaticInputs({2, 3, 4});
212207
set_attributes_map["OneHot"] = SetStaticInputs({1});
213208
set_attributes_map["Pad"] = SetStaticInputs({1});
214209
set_attributes_map["PadV2"] = SetStaticInputs({1, 2});
215-
set_attributes_map["Prod"] = SetStaticInputs({1});
216210
set_attributes_map["Reshape"] = SetStaticInputs({1});
217211
set_attributes_map["Shape"] = SetStaticInputs({0});
218212
set_attributes_map["ScatterNd"] = SetStaticInputs({2});
219213
set_attributes_map["Slice"] = SetStaticInputs({1, 2});
220214
set_attributes_map["Split"] = SetStaticInputs({0});
221215
set_attributes_map["SplitV"] = SetStaticInputs({1, 2});
222216
set_attributes_map["StridedSlice"] = SetStaticInputs({1, 2, 3});
223-
set_attributes_map["Sum"] = SetStaticInputs({1});
224217
set_attributes_map["TopKV2"] = SetStaticInputs({1});
225218
set_attributes_map["Tile"] = SetStaticInputs({1});
226219
set_attributes_map["Transpose"] = SetStaticInputs({1});

0 commit comments

Comments
 (0)