@@ -2079,45 +2079,8 @@ static Status TranslateNonMaxSuppressionV4Op(
20792079 return Status::OK ();
20802080}
20812081
2082- static Status TranslateReduceOp (
2083- const Node* op, const std::vector<const Tensor*>& static_input_map,
2084- Builder::OpMap& ng_op_map,
2085- std::function<ng::Output<ng::Node>(ng::Output<ng::Node>,
2086- ng::Output<ng::Node>, const bool )>
2087- create_ng_node) {
2088- ng::Output<ng::Node> ng_input;
2089- TF_RETURN_IF_ERROR (GetInputNode (ng_op_map, op, 0 , ng_input));
2090- bool tf_keep_dims;
2091- if (GetNodeAttr (op->attrs (), " keep_dims" , &tf_keep_dims) != Status::OK ()) {
2092- tf_keep_dims = false ;
2093- }
2094-
2095- std::vector<int64> axes;
2096- TF_RETURN_IF_ERROR (GetStaticInputVector (op, 1 , static_input_map, &axes));
2097-
2098- ng::Shape input_shape = ng_input.get_shape ();
2099- size_t input_rank = input_shape.size ();
2100-
2101- TF_RETURN_IF_ERROR (CheckAxisDimInRange (axes, input_rank));
2102-
2103- std::vector<size_t > ng_reduction_axes_vect (axes.size ());
2104- std::transform (
2105- axes.begin (), axes.end (), ng_reduction_axes_vect.begin (),
2106- [input_rank](int idx) { return idx + (idx < 0 ? (int )input_rank : 0 ); });
2107- auto ng_reduction_axes = ConstructNgNode<opset::Constant>(
2108- op->name (), ng::element::i64 , ng::Shape{ng_reduction_axes_vect.size ()},
2109- ng_reduction_axes_vect);
2110-
2111- ng::Output<ng::Node> ng_node =
2112- create_ng_node (ng_input, ng_reduction_axes, tf_keep_dims);
2113- Builder::SetTracingInfo (op->name (), ng_node);
2114-
2115- SaveNgOp (ng_op_map, op->name (), ng_node);
2116- return Status::OK ();
2117- }
2118-
21192082template <typename T>
2120- static Status TranslateDirectReduceOp (
2083+ static Status TranslateReduceOp (
21212084 const Node* op, const std::vector<const Tensor*>& static_input_map,
21222085 Builder::OpMap& ng_op_map) {
21232086 // ensure its either an arithmetic or a logical reduction
@@ -2127,13 +2090,19 @@ static Status TranslateDirectReduceOp(
21272090 " Expected node to be either a valid logical or arithmetic reduction "
21282091 " type" );
21292092 }
2130- return TranslateReduceOp (
2131- op, static_input_map, ng_op_map,
2132- [&op](ng::Output<ng::Node> ng_input,
2133- ng::Output<ng::Node> ng_reduction_axes, const bool keep_dims) {
2134- return ConstructNgNode<T>(op->name (), ng_input, ng_reduction_axes,
2135- keep_dims);
2136- });
2093+
2094+ shared_ptr<ng::Node> ng_input, ng_reduction_indices;
2095+ TF_RETURN_IF_ERROR (
2096+ GetInputNodes (ng_op_map, op, &ng_input, &ng_reduction_indices));
2097+ bool keep_dims;
2098+ if (GetNodeAttr (op->attrs (), " keep_dims" , &keep_dims) != Status::OK ()) {
2099+ keep_dims = false ;
2100+ }
2101+
2102+ auto ng_node =
2103+ ConstructNgNode<T>(op->name (), ng_input, ng_reduction_indices, keep_dims);
2104+ SaveNgOp (ng_op_map, op->name (), ng_node);
2105+ return Status::OK ();
21372106}
21382107
21392108static Status TranslateOneHotOp (
@@ -3002,8 +2971,8 @@ const static std::map<
30022971 {" Add" , TranslateBinaryOp<opset::Add>},
30032972 {" AddN" , TranslateAddNOp},
30042973 {" AddV2" , TranslateBinaryOp<opset::Add>},
3005- {" Any" , TranslateDirectReduceOp <opset::ReduceLogicalOr>},
3006- {" All" , TranslateDirectReduceOp <opset::ReduceLogicalAnd>},
2974+ {" Any" , TranslateReduceOp <opset::ReduceLogicalOr>},
2975+ {" All" , TranslateReduceOp <opset::ReduceLogicalAnd>},
30072976 {" ArgMax" , TranslateArgMaxOp},
30082977 {" ArgMin" , TranslateArgMinOp},
30092978 {" Asin" , TranslateUnaryOp<opset::Asin>},
@@ -3053,13 +3022,13 @@ const static std::map<
30533022 {" LogicalNot" , TranslateUnaryOp<opset::LogicalNot>},
30543023 {" LogicalOr" , TranslateBinaryOp<opset::LogicalOr>},
30553024 {" MatMul" , TranslateMatMulOp},
3056- {" Max" , TranslateDirectReduceOp <opset::ReduceMax>},
3025+ {" Max" , TranslateReduceOp <opset::ReduceMax>},
30573026 {" Maximum" , TranslateBinaryOp<opset::Maximum>},
30583027 {" MaxPool" , TranslateMaxPoolOp},
30593028 {" MaxPool3D" , TranslateMaxPool3DOp},
30603029 {" NonMaxSuppressionV4" , TranslateNonMaxSuppressionV4Op},
3061- {" Mean" , TranslateDirectReduceOp <opset::ReduceMean>},
3062- {" Min" , TranslateDirectReduceOp <opset::ReduceMin>},
3030+ {" Mean" , TranslateReduceOp <opset::ReduceMean>},
3031+ {" Min" , TranslateReduceOp <opset::ReduceMin>},
30633032 {" Minimum" , TranslateBinaryOp<opset::Minimum>},
30643033 {" MirrorPad" , TranslatePadOp},
30653034 {" Mul" , TranslateBinaryOp<opset::Multiply>},
@@ -3077,7 +3046,7 @@ const static std::map<
30773046 {" Pow" , TranslateBinaryOp<opset::Power>},
30783047 // PreventGradient is just Identity in dataflow terms, so reuse that.
30793048 {" PreventGradient" , TranslateIdentityOp},
3080- {" Prod" , TranslateDirectReduceOp <opset::ReduceProd>},
3049+ {" Prod" , TranslateReduceOp <opset::ReduceProd>},
30813050 {" Rank" , TranslateRankOp},
30823051 {" RealDiv" , TranslateBinaryOp<opset::Divide>},
30833052 {" Reciprocal" , TranslateReciprocalOp},
@@ -3106,7 +3075,7 @@ const static std::map<
31063075 {" Squeeze" , TranslateSqueezeOp},
31073076 {" StridedSlice" , TranslateStridedSliceOp},
31083077 {" Sub" , TranslateBinaryOp<opset::Subtract>},
3109- {" Sum" , TranslateDirectReduceOp <opset::ReduceSum>},
3078+ {" Sum" , TranslateReduceOp <opset::ReduceSum>},
31103079 {" Tan" , TranslateUnaryOp<opset::Tan>},
31113080 {" Tanh" , TranslateUnaryOp<opset::Tanh>},
31123081 {" Tile" , TranslateTileOp},
0 commit comments