Skip to content

Commit 9383b70

Browse files
author
Abhishek Kulkarni
committed
Remove "static inputs" for reduction ops
1 parent a493480 commit 9383b70

File tree

2 files changed

+21
-60
lines changed

2 files changed

+21
-60
lines changed

ngraph_bridge/ngraph_builder.cc

Lines changed: 21 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,45 +2400,8 @@ static Status TranslateNonMaxSuppressionV4Op(
24002400
return Status::OK();
24012401
}
24022402

2403-
static Status TranslateReduceOp(
2404-
const Node* op, const std::vector<const Tensor*>& static_input_map,
2405-
Builder::OpMap& ng_op_map,
2406-
std::function<std::shared_ptr<ng::Node>(
2407-
std::shared_ptr<ng::Node>, std::shared_ptr<ng::Node>, const bool)>
2408-
create_ng_node) {
2409-
shared_ptr<ng::Node> ng_input;
2410-
TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input));
2411-
bool tf_keep_dims;
2412-
if (GetNodeAttr(op->attrs(), "keep_dims", &tf_keep_dims) != Status::OK()) {
2413-
tf_keep_dims = false;
2414-
}
2415-
2416-
std::vector<int64> axes;
2417-
TF_RETURN_IF_ERROR(GetStaticInputVector(op, 1, static_input_map, &axes));
2418-
2419-
ng::Shape input_shape = ng_input->get_shape();
2420-
size_t input_rank = input_shape.size();
2421-
2422-
TF_RETURN_IF_ERROR(CheckAxisDimInRange(axes, input_rank));
2423-
2424-
std::vector<size_t> ng_reduction_axes_vect(axes.size());
2425-
std::transform(
2426-
axes.begin(), axes.end(), ng_reduction_axes_vect.begin(),
2427-
[input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
2428-
auto ng_reduction_axes = ConstructNgNode<ng::opset3::Constant>(
2429-
op->name(), ng::element::i64, ng::Shape{ng_reduction_axes_vect.size()},
2430-
ng_reduction_axes_vect);
2431-
2432-
std::shared_ptr<ng::Node> ng_node =
2433-
create_ng_node(ng_input, ng_reduction_axes, tf_keep_dims);
2434-
Builder::SetTracingInfo(op->name(), ng_node);
2435-
2436-
SaveNgOp(ng_op_map, op->name(), ng_node);
2437-
return Status::OK();
2438-
}
2439-
24402403
template <typename T>
2441-
static Status TranslateDirectReduceOp(
2404+
static Status TranslateReduceOp(
24422405
const Node* op, const std::vector<const Tensor*>& static_input_map,
24432406
Builder::OpMap& ng_op_map) {
24442407
// ensure its either an arithmetic or a logical reduction
@@ -2448,13 +2411,19 @@ static Status TranslateDirectReduceOp(
24482411
"Expected node to be either a valid logical or arithmetic reduction "
24492412
"type");
24502413
}
2451-
return TranslateReduceOp(
2452-
op, static_input_map, ng_op_map,
2453-
[&op](std::shared_ptr<ng::Node> ng_input,
2454-
std::shared_ptr<ng::Node> ng_reduction_axes, const bool keep_dims) {
2455-
return ConstructNgNode<T>(op->name(), ng_input, ng_reduction_axes,
2456-
keep_dims);
2457-
});
2414+
2415+
shared_ptr<ng::Node> ng_input, ng_reduction_indices;
2416+
TF_RETURN_IF_ERROR(
2417+
GetInputNodes(ng_op_map, op, &ng_input, &ng_reduction_indices));
2418+
bool keep_dims;
2419+
if (GetNodeAttr(op->attrs(), "keep_dims", &keep_dims) != Status::OK()) {
2420+
keep_dims = false;
2421+
}
2422+
2423+
std::shared_ptr<ng::Node> ng_node =
2424+
ConstructNgNode<T>(op->name(), ng_input, ng_reduction_indices, keep_dims);
2425+
SaveNgOp(ng_op_map, op->name(), ng_node);
2426+
return Status::OK();
24582427
}
24592428

24602429
static Status TranslateOneHotOp(
@@ -3908,8 +3877,8 @@ const static std::map<
39083877
{"Add", TranslateBinaryOp<ngraph::opset3::Add>},
39093878
{"AddN", TranslateAddNOp},
39103879
{"AddV2", TranslateBinaryOp<ngraph::opset3::Add>},
3911-
{"Any", TranslateDirectReduceOp<ng::opset3::ReduceLogicalOr>},
3912-
{"All", TranslateDirectReduceOp<ng::opset3::ReduceLogicalAnd>},
3880+
{"Any", TranslateReduceOp<ng::opset3::ReduceLogicalOr>},
3881+
{"All", TranslateReduceOp<ng::opset3::ReduceLogicalAnd>},
39133882
{"ArgMax", TranslateArgMinMaxOp<ng::op::ArgMax>},
39143883
{"ArgMin", TranslateArgMinMaxOp<ng::op::ArgMin>},
39153884
{"Asin", TranslateUnaryOp<ngraph::opset3::Asin>},
@@ -3961,13 +3930,13 @@ const static std::map<
39613930
{"LogicalNot", TranslateUnaryOp<ngraph::opset3::LogicalNot>},
39623931
{"LogicalOr", TranslateBinaryOp<ngraph::opset3::LogicalOr>},
39633932
{"MatMul", TranslateMatMulOp},
3964-
{"Max", TranslateDirectReduceOp<ng::opset3::ReduceMax>},
3933+
{"Max", TranslateReduceOp<ng::opset3::ReduceMax>},
39653934
{"Maximum", TranslateBinaryOp<ngraph::opset3::Maximum>},
39663935
{"MaxPool", TranslateMaxPoolOp},
39673936
{"MaxPool3D", TranslateMaxPool3DOp},
39683937
{"NonMaxSuppressionV4", TranslateNonMaxSuppressionV4Op},
3969-
{"Mean", TranslateDirectReduceOp<ng::opset3::ReduceMean>},
3970-
{"Min", TranslateDirectReduceOp<ng::opset3::ReduceMin>},
3938+
{"Mean", TranslateReduceOp<ng::opset3::ReduceMean>},
3939+
{"Min", TranslateReduceOp<ng::opset3::ReduceMin>},
39713940
{"Minimum", TranslateBinaryOp<ngraph::opset3::Minimum>},
39723941
{"MirrorPad", TranslatePadOp},
39733942
{"Mul", TranslateBinaryOp<ngraph::opset3::Multiply>},
@@ -3985,7 +3954,7 @@ const static std::map<
39853954
{"Pow", TranslateBinaryOp<ngraph::opset3::Power>},
39863955
// PreventGradient is just Identity in data-flow terms, so reuse that.
39873956
{"PreventGradient", TranslateIdentityOp},
3988-
{"Prod", TranslateDirectReduceOp<ng::opset3::ReduceProd>},
3957+
{"Prod", TranslateReduceOp<ng::opset3::ReduceProd>},
39893958
{"QuantizeAndDequantizeV2", TranslateQuantizeAndDequantizeV2Op},
39903959
{"QuantizedAvgPool", TranslateQuantizedAvgPoolOp},
39913960
{"QuantizedConcat", TranslateQuantizedConcatOp},
@@ -4029,7 +3998,7 @@ const static std::map<
40293998
{"Squeeze", TranslateSqueezeOp},
40303999
{"StridedSlice", TranslateStridedSliceOp},
40314000
{"Sub", TranslateBinaryOp<ngraph::opset3::Subtract>},
4032-
{"Sum", TranslateDirectReduceOp<ng::opset3::ReduceSum>},
4001+
{"Sum", TranslateReduceOp<ng::opset3::ReduceSum>},
40334002
{"Tan", TranslateUnaryOp<ngraph::opset3::Tan>},
40344003
{"Tanh", TranslateUnaryOp<ngraph::opset3::Tanh>},
40354004
{"Tile", TranslateTileOp},

ngraph_bridge/ngraph_mark_for_clustering.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,25 +196,18 @@ const std::map<std::string, SetAttributesFunction>& GetAttributeSetters() {
196196

197197
if (!initialized) {
198198
// Set Additional Attributes (if any)
199-
set_attributes_map["Any"] = SetStaticInputs({1});
200-
set_attributes_map["All"] = SetStaticInputs({1});
201199
set_attributes_map["ArgMax"] = SetStaticInputs({1});
202200
set_attributes_map["ArgMin"] = SetStaticInputs({1});
203201
set_attributes_map["ConcatV2"] = SetStaticInputs({-1});
204202
set_attributes_map["Conv2DBackpropInput"] = SetStaticInputs({0});
205203
set_attributes_map["ExpandDims"] = SetStaticInputs({1});
206204
set_attributes_map["Fill"] = SetStaticInputs({0});
207205
set_attributes_map["GatherV2"] = SetStaticInputs({2});
208-
set_attributes_map["Max"] = SetStaticInputs({1});
209-
set_attributes_map["Mean"] = SetStaticInputs({1});
210-
set_attributes_map["Min"] = SetStaticInputs({1});
211206
set_attributes_map["MirrorPad"] = SetStaticInputs({1});
212207
set_attributes_map["NonMaxSuppressionV4"] = SetStaticInputs({2, 3, 4});
213208
set_attributes_map["OneHot"] = SetStaticInputs({1});
214209
set_attributes_map["Pad"] = SetStaticInputs({1});
215210
set_attributes_map["PadV2"] = SetStaticInputs({1, 2});
216-
set_attributes_map["Prod"] = SetStaticInputs({1});
217-
218211
set_attributes_map["QuantizeAndDequantizeV2"] = SetStaticInputs({1, 2});
219212
set_attributes_map["QuantizedConcat"] = [](Node* n) {
220213
SetStaticInputs(n, {0}); // the axis
@@ -242,7 +235,6 @@ const std::map<std::string, SetAttributesFunction>& GetAttributeSetters() {
242235
set_attributes_map["Split"] = SetStaticInputs({0});
243236
set_attributes_map["SplitV"] = SetStaticInputs({1, 2});
244237
set_attributes_map["StridedSlice"] = SetStaticInputs({1, 2, 3});
245-
set_attributes_map["Sum"] = SetStaticInputs({1});
246238
set_attributes_map["TopKV2"] = SetStaticInputs({1});
247239
set_attributes_map["Tile"] = SetStaticInputs({1});
248240
set_attributes_map["Transpose"] = SetStaticInputs({1});

0 commit comments

Comments
 (0)