Skip to content

Commit 2f8c3e7

Browse files
Merge branch 'master' into r0.13
2 parents f0b86b9 + 457af5b commit 2f8c3e7

File tree

5 files changed

+401
-112
lines changed

5 files changed

+401
-112
lines changed

python/ngraph_bridge/__init__.in.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tensorflow.python.framework import ops
3535

3636
from tensorflow.core.protobuf import rewriter_config_pb2
37+
from tensorflow.python.framework import load_library
3738

3839
import ctypes
3940

@@ -92,8 +93,9 @@
9293
(TF_INSTALLED_VER[1] == TF_NEEDED_VER[1]) and \
9394
((TF_INSTALLED_VER[2].split('-'))[0] == (TF_NEEDED_VER[2].split('-'))[0]):
9495
libpath = os.path.dirname(__file__)
95-
ngraph_bridge_lib = ctypes.cdll.LoadLibrary(
96-
os.path.join(libpath, 'libngraph_bridge.' + ext))
96+
full_lib_path = os.path.join(libpath, 'libngraph_bridge.' + ext)
97+
_ = load_library.load_op_library(full_lib_path)
98+
ngraph_bridge_lib = ctypes.cdll.LoadLibrary(full_lib_path)
9799
else:
98100
raise ValueError(
99101
"Error: Installed TensorFlow version {0}\nnGraph bridge built with: {1}"
@@ -227,4 +229,3 @@ def get_disabled_ops():
227229
"nGraph bridge built with Grappler: " + str(ngraph_bridge_lib.ngraph_tf_is_grappler_enabled()) + "\n" \
228230
"nGraph bridge built with Variables and Optimizers Enablement: " \
229231
+ str(ngraph_bridge_lib.ngraph_tf_are_variables_enabled())
230-

src/ngraph_builder.cc

Lines changed: 127 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,9 @@ static Status TranslateBatchMatMulOp(
750750
shared_ptr<ng::Node> ng_lhs, ng_rhs;
751751
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_lhs, &ng_rhs));
752752

753+
std::string backend_name;
754+
TF_RETURN_IF_ERROR(ngraph_bridge::GetNodeBackend(op, &backend_name));
755+
753756
auto ng_lhs_shape = ng_lhs->get_shape();
754757
auto ng_rhs_shape = ng_rhs->get_shape();
755758

@@ -781,77 +784,138 @@ static Status TranslateBatchMatMulOp(
781784

782785
auto ng_lhs_axes = out_axes;
783786
auto ng_rhs_axes = out_axes;
784-
if (tf_adj_x) {
785-
ng_lhs_axes.push_back(n_dims - 1);
786-
ng_lhs_axes.push_back(n_dims - 2);
787-
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
788-
}
789-
if (tf_adj_y) {
790-
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
791-
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
792-
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
793-
} else {
794-
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
795-
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
796-
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
797-
}
798787

799-
ng_lhs_shape = ng_lhs->get_shape();
800-
ng_rhs_shape = ng_rhs->get_shape();
801-
802-
if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) {
803-
return errors::InvalidArgument(
804-
"The last dimension of ng_lhs and the first dimension of ng_rhs "
805-
"should have the same size");
806-
}
807-
if (n_dims == 2) {
808-
SaveNgOp(ng_op_map, op->name(),
809-
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
810-
} else {
811-
auto output_shape = ng_lhs_shape;
812-
output_shape[n_dims - 1] = ng_rhs_shape[1];
813-
auto dot_output =
814-
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs);
815-
size_t compound_size = 1;
816-
for (int i = 0; i < out_axes.size(); i++) {
817-
compound_size *= output_shape[i];
788+
// Get the backend name, if the backend is CPU and n_dims >= 3
789+
// then use the BatchMatMul op supported by nGraph
790+
if (n_dims >= 3 && backend_name == "CPU") {
791+
// Transpose X if AdjX = true
792+
if (tf_adj_x) {
793+
ng_lhs_axes.push_back(n_dims - 1);
794+
ng_lhs_axes.push_back(n_dims - 2);
795+
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
796+
ng_lhs_shape = ng_lhs->get_shape();
797+
} else {
798+
ng_lhs_axes.push_back(n_dims - 2);
799+
ng_lhs_axes.push_back(n_dims - 1);
818800
}
819-
auto dot_axes = out_axes;
820-
dot_axes.push_back(n_dims - 2);
821-
dot_axes.push_back(n_dims - 1);
822-
for (int i = 0; i < out_axes.size(); i++) {
823-
dot_axes.push_back(n_dims + i);
801+
// Transpose Y if AdjY = true
802+
if (tf_adj_y) {
803+
ng_rhs_axes.push_back(n_dims - 1);
804+
ng_rhs_axes.push_back(n_dims - 2);
805+
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
806+
ng_rhs_shape = ng_rhs->get_shape();
807+
} else {
808+
ng_rhs_axes.push_back(n_dims - 2);
809+
ng_rhs_axes.push_back(n_dims - 1);
824810
}
825-
ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2],
826-
ng_rhs_shape[1], compound_size};
827-
std::shared_ptr<ng::Node> dot_reshape;
811+
828812
if (n_dims == 3) {
829-
dot_reshape = dot_output;
813+
SaveNgOp(ng_op_map, op->name(), ConstructNgNode<ngraph::op::BatchMatMul>(
814+
op->name(), ng_lhs, ng_rhs));
830815
} else {
831-
dot_reshape = ConstructNgNode<ngraph::op::Reshape>(op->name(), dot_output,
832-
dot_axes, dot_shape);
816+
// Find the compound size for dim1 so as to reshape to 3D
817+
size_t compound_size = 1;
818+
for (int i = 0; i < out_axes.size(); i++) {
819+
compound_size *= ng_lhs_shape[i];
820+
}
821+
822+
ng::Shape tmp_lhs_shape = {compound_size, ng_lhs_shape[n_dims - 2],
823+
ng_lhs_shape[n_dims - 1]};
824+
ng::Shape tmp_rhs_shape = {compound_size, ng_rhs_shape[n_dims - 2],
825+
ng_rhs_shape[n_dims - 1]};
826+
827+
auto output_shape = ng_lhs_shape;
828+
output_shape[n_dims - 1] = ng_rhs_shape[n_dims - 1];
829+
ng::AxisVector tmp_axes = {0, 1, 2};
830+
831+
std::shared_ptr<ng::Node> lhs_reshape =
832+
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_lhs, ng_lhs_axes,
833+
tmp_lhs_shape);
834+
std::shared_ptr<ng::Node> rhs_reshape =
835+
ConstructNgNode<ngraph::op::Reshape>(op->name(), ng_rhs, ng_rhs_axes,
836+
tmp_rhs_shape);
837+
std::shared_ptr<ng::Node> batchmatmul =
838+
ConstructNgNode<ngraph::op::BatchMatMul>(op->name(), lhs_reshape,
839+
rhs_reshape);
840+
SaveNgOp(ng_op_map, op->name(),
841+
ConstructNgNode<ngraph::op::Reshape>(op->name(), batchmatmul,
842+
tmp_axes, output_shape));
833843
}
834-
ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]};
835-
vector<shared_ptr<ngraph::Node>> tmp_tensors;
836-
for (size_t i = 0; i < dot_shape[0]; i++) {
837-
const std::vector<size_t> lower_bound{i, 0, 0, i};
838-
const std::vector<size_t> upper_bound{i + 1, dot_shape[1], dot_shape[2],
839-
i + 1};
840-
auto slice_out = ConstructNgNode<ngraph::op::Slice>(
841-
op->name(), dot_reshape, lower_bound, upper_bound);
842-
auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
843-
op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape);
844-
tmp_tensors.push_back(reshape_out);
844+
} else {
845+
if (tf_adj_x) {
846+
ng_lhs_axes.push_back(n_dims - 1);
847+
ng_lhs_axes.push_back(n_dims - 2);
848+
ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
845849
}
846-
auto concat_op =
847-
ConstructNgNode<ngraph::op::Concat>(op->name(), tmp_tensors, 0);
848-
if (n_dims == 3) {
849-
SaveNgOp(ng_op_map, op->name(), concat_op);
850+
if (tf_adj_y) {
851+
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
852+
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
853+
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
854+
} else {
855+
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
856+
ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
857+
ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
858+
}
859+
860+
ng_lhs_shape = ng_lhs->get_shape();
861+
ng_rhs_shape = ng_rhs->get_shape();
862+
863+
if (ng_lhs_shape[n_dims - 1] != ng_rhs_shape[0]) {
864+
return errors::InvalidArgument(
865+
"The last dimension of ng_lhs and the first dimension of ng_rhs "
866+
"should have the same size");
867+
}
868+
869+
if (n_dims == 2) {
870+
SaveNgOp(ng_op_map, op->name(),
871+
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs));
850872
} else {
851-
SaveNgOp(
852-
ng_op_map, op->name(),
853-
ConstructNgNode<ngraph::op::Reshape>(
854-
op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape));
873+
auto output_shape = ng_lhs_shape;
874+
output_shape[n_dims - 1] = ng_rhs_shape[1];
875+
auto dot_output =
876+
ConstructNgNode<ngraph::op::Dot>(op->name(), ng_lhs, ng_rhs);
877+
878+
size_t compound_size = 1;
879+
for (int i = 0; i < out_axes.size(); i++) {
880+
compound_size *= output_shape[i];
881+
}
882+
auto dot_axes = out_axes;
883+
dot_axes.push_back(n_dims - 2);
884+
dot_axes.push_back(n_dims - 1);
885+
for (int i = 0; i < out_axes.size(); i++) {
886+
dot_axes.push_back(n_dims + i);
887+
}
888+
ng::Shape dot_shape = {compound_size, ng_lhs_shape[n_dims - 2],
889+
ng_rhs_shape[1], compound_size};
890+
std::shared_ptr<ng::Node> dot_reshape;
891+
if (n_dims == 3) {
892+
dot_reshape = dot_output;
893+
} else {
894+
dot_reshape = ConstructNgNode<ngraph::op::Reshape>(
895+
op->name(), dot_output, dot_axes, dot_shape);
896+
}
897+
ng::Shape tmp_shape = {1, ng_lhs_shape[n_dims - 2], ng_rhs_shape[1]};
898+
vector<shared_ptr<ngraph::Node>> tmp_tensors;
899+
for (size_t i = 0; i < dot_shape[0]; i++) {
900+
const std::vector<size_t> lower_bound{i, 0, 0, i};
901+
const std::vector<size_t> upper_bound{i + 1, dot_shape[1], dot_shape[2],
902+
i + 1};
903+
auto slice_out = ConstructNgNode<ngraph::op::Slice>(
904+
op->name(), dot_reshape, lower_bound, upper_bound);
905+
auto reshape_out = ConstructNgNode<ngraph::op::Reshape>(
906+
op->name(), slice_out, ng::AxisVector{0, 1, 2, 3}, tmp_shape);
907+
tmp_tensors.push_back(reshape_out);
908+
}
909+
auto concat_op =
910+
ConstructNgNode<ngraph::op::Concat>(op->name(), tmp_tensors, 0);
911+
if (n_dims == 3) {
912+
SaveNgOp(ng_op_map, op->name(), concat_op);
913+
} else {
914+
SaveNgOp(
915+
ng_op_map, op->name(),
916+
ConstructNgNode<ngraph::op::Reshape>(
917+
op->name(), concat_op, ng::AxisVector{0, 1, 2}, output_shape));
918+
}
855919
}
856920
}
857921
return Status::OK();

test/opexecuter.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,29 @@ void OpExecuter::ExecuteOnNGraph(vector<Tensor>& ngraph_outputs,
200200

201201
// Get Tensor input shapes and values from the const nodes
202202
int number_of_inputs = test_op->num_inputs();
203+
204+
// Create nGraph backend
205+
// If NGRAPH_TF_BACKEND is set create that backend
206+
// Else create backend of type ng_backend_name
207+
string ng_backend_type = ng_backend_name;
208+
const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND");
209+
210+
if (ng_backend_env_value != nullptr) {
211+
string backend_env = std::string(ng_backend_env_value);
212+
bool valid_ngraph_tf_backend =
213+
!backend_env.empty() && BackendManager::IsSupportedBackend(backend_env);
214+
ASSERT_TRUE(valid_ngraph_tf_backend) << "NGRAPH_TF_BACKEND " << backend_env
215+
<< " is not a supported backend";
216+
ng_backend_type = backend_env;
217+
}
218+
219+
NGRAPH_VLOG(5) << " Creating NG Backend " << ng_backend_type;
220+
BackendManager::CreateBackend(ng_backend_type);
221+
auto backend = BackendManager::GetBackend(ng_backend_type);
222+
223+
// Add the _ngraph_backend attr to the node
224+
test_op->AddAttr("_ngraph_backend", ng_backend_type);
225+
203226
// TODO : Validate static_input_indexes < number_of_inputs
204227
vector<TensorShape> input_shapes;
205228
vector<DataType> input_dt;
@@ -328,24 +351,6 @@ void OpExecuter::ExecuteOnNGraph(vector<Tensor>& ngraph_outputs,
328351
NgraphSerialize("unit_test_" + test_op_type_ + ".json", ng_function);
329352
}
330353

331-
// Create nGraph backend
332-
// If NGRAPH_TF_BACKEND is set create that backend
333-
// Else create backend of type ng_backend_name
334-
string ng_backend_type = ng_backend_name;
335-
const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND");
336-
if (ng_backend_env_value != nullptr) {
337-
string backend_env = std::string(ng_backend_env_value);
338-
bool valid_ngraph_tf_backend =
339-
!backend_env.empty() && BackendManager::IsSupportedBackend(backend_env);
340-
ASSERT_TRUE(valid_ngraph_tf_backend) << "NGRAPH_TF_BACKEND " << backend_env
341-
<< " is not a supported backend";
342-
ng_backend_type = backend_env;
343-
}
344-
345-
NGRAPH_VLOG(5) << " Creating NG Backend " << ng_backend_type;
346-
BackendManager::CreateBackend(ng_backend_type);
347-
auto backend = BackendManager::GetBackend(ng_backend_type);
348-
349354
// Allocate tensors for inputs
350355
vector<std::shared_ptr<ngraph::runtime::Tensor>> ng_ip_tensors;
351356
vector<std::shared_ptr<ngraph::runtime::Tensor>> ng_op_tensors;

0 commit comments

Comments
 (0)