@@ -750,6 +750,9 @@ static Status TranslateBatchMatMulOp(
750750 shared_ptr<ng::Node> ng_lhs, ng_rhs;
751751 TF_RETURN_IF_ERROR (GetInputNodes (ng_op_map, op, &ng_lhs, &ng_rhs));
752752
753+ std::string backend_name;
754+ TF_RETURN_IF_ERROR (ngraph_bridge::GetNodeBackend (op, &backend_name));
755+
753756 auto ng_lhs_shape = ng_lhs->get_shape ();
754757 auto ng_rhs_shape = ng_rhs->get_shape ();
755758
@@ -781,77 +784,138 @@ static Status TranslateBatchMatMulOp(
781784
782785 auto ng_lhs_axes = out_axes;
783786 auto ng_rhs_axes = out_axes;
784- if (tf_adj_x) {
785- ng_lhs_axes.push_back (n_dims - 1 );
786- ng_lhs_axes.push_back (n_dims - 2 );
787- ng_lhs = ng::builder::numpy_transpose (ng_lhs, ng_lhs_axes);
788- }
789- if (tf_adj_y) {
790- ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 2 );
791- ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 1 );
792- ng_rhs = ng::builder::numpy_transpose (ng_rhs, ng_rhs_axes);
793- } else {
794- ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 1 );
795- ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 2 );
796- ng_rhs = ng::builder::numpy_transpose (ng_rhs, ng_rhs_axes);
797- }
798787
799- ng_lhs_shape = ng_lhs->get_shape ();
800- ng_rhs_shape = ng_rhs->get_shape ();
801-
802- if (ng_lhs_shape[n_dims - 1 ] != ng_rhs_shape[0 ]) {
803- return errors::InvalidArgument (
804- " The last dimension of ng_lhs and the first dimension of ng_rhs "
805- " should have the same size" );
806- }
807- if (n_dims == 2 ) {
808- SaveNgOp (ng_op_map, op->name (),
809- ConstructNgNode<ngraph::op::Dot>(op->name (), ng_lhs, ng_rhs));
810- } else {
811- auto output_shape = ng_lhs_shape;
812- output_shape[n_dims - 1 ] = ng_rhs_shape[1 ];
813- auto dot_output =
814- ConstructNgNode<ngraph::op::Dot>(op->name (), ng_lhs, ng_rhs);
815- size_t compound_size = 1 ;
816- for (int i = 0 ; i < out_axes.size (); i++) {
817- compound_size *= output_shape[i];
788+ // Get the backend name, if the backend is CPU and n_dims >= 3
789+ // then use the BatchMatMul op supported by nGraph
790+ if (n_dims >= 3 && backend_name == " CPU" ) {
791+ // Transpose X if AdjX = true
792+ if (tf_adj_x) {
793+ ng_lhs_axes.push_back (n_dims - 1 );
794+ ng_lhs_axes.push_back (n_dims - 2 );
795+ ng_lhs = ng::builder::numpy_transpose (ng_lhs, ng_lhs_axes);
796+ ng_lhs_shape = ng_lhs->get_shape ();
797+ } else {
798+ ng_lhs_axes.push_back (n_dims - 2 );
799+ ng_lhs_axes.push_back (n_dims - 1 );
818800 }
819- auto dot_axes = out_axes;
820- dot_axes.push_back (n_dims - 2 );
821- dot_axes.push_back (n_dims - 1 );
822- for (int i = 0 ; i < out_axes.size (); i++) {
823- dot_axes.push_back (n_dims + i);
801+ // Transpose Y if AdjY = true
802+ if (tf_adj_y) {
803+ ng_rhs_axes.push_back (n_dims - 1 );
804+ ng_rhs_axes.push_back (n_dims - 2 );
805+ ng_rhs = ng::builder::numpy_transpose (ng_rhs, ng_rhs_axes);
806+ ng_rhs_shape = ng_rhs->get_shape ();
807+ } else {
808+ ng_rhs_axes.push_back (n_dims - 2 );
809+ ng_rhs_axes.push_back (n_dims - 1 );
824810 }
825- ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2 ],
826- ng_rhs_shape[1 ], compound_size};
827- std::shared_ptr<ng::Node> dot_reshape;
811+
828812 if (n_dims == 3 ) {
829- dot_reshape = dot_output;
813+ SaveNgOp (ng_op_map, op->name (), ConstructNgNode<ngraph::op::BatchMatMul>(
814+ op->name (), ng_lhs, ng_rhs));
830815 } else {
831- dot_reshape = ConstructNgNode<ngraph::op::Reshape>(op->name (), dot_output,
832- dot_axes, dot_shape);
816+ // Find the compound size for dim1 so as to reshape to 3D
817+ size_t compound_size = 1 ;
818+ for (int i = 0 ; i < out_axes.size (); i++) {
819+ compound_size *= ng_lhs_shape[i];
820+ }
821+
822+ ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2 ],
823+ ng_lhs_shape[n_dims - 1 ]};
824+ ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2 ],
825+ ng_rhs_shape[n_dims - 1 ]};
826+
827+ auto output_shape = ng_lhs_shape;
828+ output_shape[n_dims - 1 ] = ng_rhs_shape[n_dims - 1 ];
829+ ng::AxisVector tmp_axes = {0 , 1 , 2 };
830+
831+ std::shared_ptr<ng::Node> lhs_reshape =
832+ ConstructNgNode<ngraph::op::Reshape>(op->name (), ng_lhs, ng_lhs_axes,
833+ tmp_lhs_shape);
834+ std::shared_ptr<ng::Node> rhs_reshape =
835+ ConstructNgNode<ngraph::op::Reshape>(op->name (), ng_rhs, ng_rhs_axes,
836+ tmp_rhs_shape);
837+ std::shared_ptr<ng::Node> batchmatmul =
838+ ConstructNgNode<ngraph::op::BatchMatMul>(op->name (), lhs_reshape,
839+ rhs_reshape);
840+ SaveNgOp (ng_op_map, op->name (),
841+ ConstructNgNode<ngraph::op::Reshape>(op->name (), batchmatmul,
842+ tmp_axes, output_shape));
833843 }
834- ng::Shape tmp_shape = {1 , ng_lhs_shape[n_dims - 2 ], ng_rhs_shape[1 ]};
835- vector<shared_ptr<ngraph::Node>> tmp_tensors;
836- for (size_t i = 0 ; i < dot_shape[0 ]; i++) {
837- const std::vector<size_t > lower_bound{i, 0 , 0 , i};
838- const std::vector<size_t > upper_bound{i + 1 , dot_shape[1 ], dot_shape[2 ],
839- i + 1 };
840- auto slice_out = ConstructNgNode<ngraph::op::Slice>(
841- op->name (), dot_reshape, lower_bound, upper_bound);
842- auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
843- op->name (), slice_out, ng::AxisVector{0 , 1 , 2 , 3 }, tmp_shape);
844- tmp_tensors.push_back (reshape_out);
844+ } else {
845+ if (tf_adj_x) {
846+ ng_lhs_axes.push_back (n_dims - 1 );
847+ ng_lhs_axes.push_back (n_dims - 2 );
848+ ng_lhs = ng::builder::numpy_transpose (ng_lhs, ng_lhs_axes);
845849 }
846- auto concat_op =
847- ConstructNgNode<ngraph::op::Concat>(op->name (), tmp_tensors, 0 );
848- if (n_dims == 3 ) {
849- SaveNgOp (ng_op_map, op->name (), concat_op);
850+ if (tf_adj_y) {
851+ ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 2 );
852+ ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 1 );
853+ ng_rhs = ng::builder::numpy_transpose (ng_rhs, ng_rhs_axes);
854+ } else {
855+ ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 1 );
856+ ng_rhs_axes.insert (ng_rhs_axes.begin (), n_dims - 2 );
857+ ng_rhs = ng::builder::numpy_transpose (ng_rhs, ng_rhs_axes);
858+ }
859+
860+ ng_lhs_shape = ng_lhs->get_shape ();
861+ ng_rhs_shape = ng_rhs->get_shape ();
862+
863+ if (ng_lhs_shape[n_dims - 1 ] != ng_rhs_shape[0 ]) {
864+ return errors::InvalidArgument (
865+ " The last dimension of ng_lhs and the first dimension of ng_rhs "
866+ " should have the same size" );
867+ }
868+
869+ if (n_dims == 2 ) {
870+ SaveNgOp (ng_op_map, op->name (),
871+ ConstructNgNode<ngraph::op::Dot>(op->name (), ng_lhs, ng_rhs));
850872 } else {
851- SaveNgOp (
852- ng_op_map, op->name (),
853- ConstructNgNode<ngraph::op::Reshape>(
854- op->name (), concat_op, ng::AxisVector{0 , 1 , 2 }, output_shape));
873+ auto output_shape = ng_lhs_shape;
874+ output_shape[n_dims - 1 ] = ng_rhs_shape[1 ];
875+ auto dot_output =
876+ ConstructNgNode<ngraph::op::Dot>(op->name (), ng_lhs, ng_rhs);
877+
878+ size_t compound_size = 1 ;
879+ for (int i = 0 ; i < out_axes.size (); i++) {
880+ compound_size *= output_shape[i];
881+ }
882+ auto dot_axes = out_axes;
883+ dot_axes.push_back (n_dims - 2 );
884+ dot_axes.push_back (n_dims - 1 );
885+ for (int i = 0 ; i < out_axes.size (); i++) {
886+ dot_axes.push_back (n_dims + i);
887+ }
888+ ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2 ],
889+ ng_rhs_shape[1 ], compound_size};
890+ std::shared_ptr<ng::Node> dot_reshape;
891+ if (n_dims == 3 ) {
892+ dot_reshape = dot_output;
893+ } else {
894+ dot_reshape = ConstructNgNode<ngraph::op::Reshape>(
895+ op->name (), dot_output, dot_axes, dot_shape);
896+ }
897+ ng::Shape tmp_shape = {1 , ng_lhs_shape[n_dims - 2 ], ng_rhs_shape[1 ]};
898+ vector<shared_ptr<ngraph::Node>> tmp_tensors;
899+ for (size_t i = 0 ; i < dot_shape[0 ]; i++) {
900+ const std::vector<size_t > lower_bound{i, 0 , 0 , i};
901+ const std::vector<size_t > upper_bound{i + 1 , dot_shape[1 ], dot_shape[2 ],
902+ i + 1 };
903+ auto slice_out = ConstructNgNode<ngraph::op::Slice>(
904+ op->name (), dot_reshape, lower_bound, upper_bound);
905+ auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
906+ op->name (), slice_out, ng::AxisVector{0 , 1 , 2 , 3 }, tmp_shape);
907+ tmp_tensors.push_back (reshape_out);
908+ }
909+ auto concat_op =
910+ ConstructNgNode<ngraph::op::Concat>(op->name (), tmp_tensors, 0 );
911+ if (n_dims == 3 ) {
912+ SaveNgOp (ng_op_map, op->name (), concat_op);
913+ } else {
914+ SaveNgOp (
915+ ng_op_map, op->name (),
916+ ConstructNgNode<ngraph::op::Reshape>(
917+ op->name (), concat_op, ng::AxisVector{0 , 1 , 2 }, output_shape));
918+ }
855919 }
856920 }
857921 return Status::OK ();
0 commit comments