@@ -2400,45 +2400,8 @@ static Status TranslateNonMaxSuppressionV4Op(
24002400 return Status::OK ();
24012401}
24022402
2403- static Status TranslateReduceOp (
2404- const Node* op, const std::vector<const Tensor*>& static_input_map,
2405- Builder::OpMap& ng_op_map,
2406- std::function<std::shared_ptr<ng::Node>(
2407- std::shared_ptr<ng::Node>, std::shared_ptr<ng::Node>, const bool )>
2408- create_ng_node) {
2409- shared_ptr<ng::Node> ng_input;
2410- TF_RETURN_IF_ERROR (GetInputNode (ng_op_map, op, 0 , &ng_input));
2411- bool tf_keep_dims;
2412- if (GetNodeAttr (op->attrs (), " keep_dims" , &tf_keep_dims) != Status::OK ()) {
2413- tf_keep_dims = false ;
2414- }
2415-
2416- std::vector<int64> axes;
2417- TF_RETURN_IF_ERROR (GetStaticInputVector (op, 1 , static_input_map, &axes));
2418-
2419- ng::Shape input_shape = ng_input->get_shape ();
2420- size_t input_rank = input_shape.size ();
2421-
2422- TF_RETURN_IF_ERROR (CheckAxisDimInRange (axes, input_rank));
2423-
2424- std::vector<size_t > ng_reduction_axes_vect (axes.size ());
2425- std::transform (
2426- axes.begin (), axes.end (), ng_reduction_axes_vect.begin (),
2427- [input_rank](int idx) { return idx + (idx < 0 ? (int )input_rank : 0 ); });
2428- auto ng_reduction_axes = ConstructNgNode<ng::opset3::Constant>(
2429- op->name (), ng::element::i64 , ng::Shape{ng_reduction_axes_vect.size ()},
2430- ng_reduction_axes_vect);
2431-
2432- std::shared_ptr<ng::Node> ng_node =
2433- create_ng_node (ng_input, ng_reduction_axes, tf_keep_dims);
2434- Builder::SetTracingInfo (op->name (), ng_node);
2435-
2436- SaveNgOp (ng_op_map, op->name (), ng_node);
2437- return Status::OK ();
2438- }
2439-
24402403template <typename T>
2441- static Status TranslateDirectReduceOp (
2404+ static Status TranslateReduceOp (
24422405 const Node* op, const std::vector<const Tensor*>& static_input_map,
24432406 Builder::OpMap& ng_op_map) {
24442407 // ensure its either an arithmetic or a logical reduction
@@ -2448,13 +2411,19 @@ static Status TranslateDirectReduceOp(
24482411 " Expected node to be either a valid logical or arithmetic reduction "
24492412 " type" );
24502413 }
2451- return TranslateReduceOp (
2452- op, static_input_map, ng_op_map,
2453- [&op](std::shared_ptr<ng::Node> ng_input,
2454- std::shared_ptr<ng::Node> ng_reduction_axes, const bool keep_dims) {
2455- return ConstructNgNode<T>(op->name (), ng_input, ng_reduction_axes,
2456- keep_dims);
2457- });
2414+
2415+ shared_ptr<ng::Node> ng_input, ng_reduction_indices;
2416+ TF_RETURN_IF_ERROR (
2417+ GetInputNodes (ng_op_map, op, &ng_input, &ng_reduction_indices));
2418+ bool keep_dims;
2419+ if (GetNodeAttr (op->attrs (), " keep_dims" , &keep_dims) != Status::OK ()) {
2420+ keep_dims = false ;
2421+ }
2422+
2423+ std::shared_ptr<ng::Node> ng_node =
2424+ ConstructNgNode<T>(op->name (), ng_input, ng_reduction_indices, keep_dims);
2425+ SaveNgOp (ng_op_map, op->name (), ng_node);
2426+ return Status::OK ();
24582427}
24592428
24602429static Status TranslateOneHotOp (
@@ -3908,8 +3877,8 @@ const static std::map<
39083877 {" Add" , TranslateBinaryOp<ngraph::opset3::Add>},
39093878 {" AddN" , TranslateAddNOp},
39103879 {" AddV2" , TranslateBinaryOp<ngraph::opset3::Add>},
3911- {" Any" , TranslateDirectReduceOp <ng::opset3::ReduceLogicalOr>},
3912- {" All" , TranslateDirectReduceOp <ng::opset3::ReduceLogicalAnd>},
3880+ {" Any" , TranslateReduceOp <ng::opset3::ReduceLogicalOr>},
3881+ {" All" , TranslateReduceOp <ng::opset3::ReduceLogicalAnd>},
39133882 {" ArgMax" , TranslateArgMinMaxOp<ng::op::ArgMax>},
39143883 {" ArgMin" , TranslateArgMinMaxOp<ng::op::ArgMin>},
39153884 {" Asin" , TranslateUnaryOp<ngraph::opset3::Asin>},
@@ -3961,13 +3930,13 @@ const static std::map<
39613930 {" LogicalNot" , TranslateUnaryOp<ngraph::opset3::LogicalNot>},
39623931 {" LogicalOr" , TranslateBinaryOp<ngraph::opset3::LogicalOr>},
39633932 {" MatMul" , TranslateMatMulOp},
3964- {" Max" , TranslateDirectReduceOp <ng::opset3::ReduceMax>},
3933+ {" Max" , TranslateReduceOp <ng::opset3::ReduceMax>},
39653934 {" Maximum" , TranslateBinaryOp<ngraph::opset3::Maximum>},
39663935 {" MaxPool" , TranslateMaxPoolOp},
39673936 {" MaxPool3D" , TranslateMaxPool3DOp},
39683937 {" NonMaxSuppressionV4" , TranslateNonMaxSuppressionV4Op},
3969- {" Mean" , TranslateDirectReduceOp <ng::opset3::ReduceMean>},
3970- {" Min" , TranslateDirectReduceOp <ng::opset3::ReduceMin>},
3938+ {" Mean" , TranslateReduceOp <ng::opset3::ReduceMean>},
3939+ {" Min" , TranslateReduceOp <ng::opset3::ReduceMin>},
39713940 {" Minimum" , TranslateBinaryOp<ngraph::opset3::Minimum>},
39723941 {" MirrorPad" , TranslatePadOp},
39733942 {" Mul" , TranslateBinaryOp<ngraph::opset3::Multiply>},
@@ -3985,7 +3954,7 @@ const static std::map<
39853954 {" Pow" , TranslateBinaryOp<ngraph::opset3::Power>},
39863955 // PreventGradient is just Identity in data-flow terms, so reuse that.
39873956 {" PreventGradient" , TranslateIdentityOp},
3988- {" Prod" , TranslateDirectReduceOp <ng::opset3::ReduceProd>},
3957+ {" Prod" , TranslateReduceOp <ng::opset3::ReduceProd>},
39893958 {" QuantizeAndDequantizeV2" , TranslateQuantizeAndDequantizeV2Op},
39903959 {" QuantizedAvgPool" , TranslateQuantizedAvgPoolOp},
39913960 {" QuantizedConcat" , TranslateQuantizedConcatOp},
@@ -4029,7 +3998,7 @@ const static std::map<
40293998 {" Squeeze" , TranslateSqueezeOp},
40303999 {" StridedSlice" , TranslateStridedSliceOp},
40314000 {" Sub" , TranslateBinaryOp<ngraph::opset3::Subtract>},
4032- {" Sum" , TranslateDirectReduceOp <ng::opset3::ReduceSum>},
4001+ {" Sum" , TranslateReduceOp <ng::opset3::ReduceSum>},
40334002 {" Tan" , TranslateUnaryOp<ngraph::opset3::Tan>},
40344003 {" Tanh" , TranslateUnaryOp<ngraph::opset3::Tanh>},
40354004 {" Tile" , TranslateTileOp},
0 commit comments