From 7becb4c220fcb8809edfae3b74fa68765bc1c69a Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Tue, 20 Oct 2020 17:04:21 -0700 Subject: [PATCH] Move nGraph rewrite pass POST_PLACEMENT --- ngraph_bridge/CMakeLists.txt | 3 +- ngraph_bridge/ngraph_mark_for_clustering.cc | 48 +++++++++++++++++++-- ngraph_bridge/ngraph_rewrite_pass.cc | 3 +- test/tests_linux_cpu.txt | 4 +- 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/ngraph_bridge/CMakeLists.txt b/ngraph_bridge/CMakeLists.txt index 3f34c294c..303813986 100644 --- a/ngraph_bridge/CMakeLists.txt +++ b/ngraph_bridge/CMakeLists.txt @@ -52,13 +52,14 @@ set(SRC tf_graphcycles.cc tf_deadness_analysis.cc version.cc + grappler/ngraph_add_identityn.cc ) message(STATUS "NGRAPH_TF_USE_GRAPPLER_OPTIMIZER: ${NGRAPH_TF_USE_GRAPPLER_OPTIMIZER}") if(NGRAPH_TF_USE_GRAPPLER_OPTIMIZER) list(REMOVE_ITEM SRC ngraph_rewrite_pass.cc) list(APPEND SRC grappler/ngraph_optimizer.cc) - list(APPEND SRC grappler/ngraph_add_identityn.cc) + # list(APPEND SRC grappler/ngraph_add_identityn.cc) add_definitions(-DNGRAPH_TF_USE_GRAPPLER_OPTIMIZER) endif() diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc index 60bbaad09..61bdb055f 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.cc +++ b/ngraph_bridge/ngraph_mark_for_clustering.cc @@ -14,9 +14,11 @@ * limitations under the License. *******************************************************************************/ +#include "tensorflow/core/common_runtime/build_graph_options.h" #include "tensorflow/core/graph/graph.h" #include "ngraph_bridge/default_opset.h" +#include "ngraph_bridge/grappler/ngraph_add_identityn.h" #include "ngraph_bridge/ngraph_api.h" #include "ngraph_bridge/ngraph_backend_manager.h" #include "ngraph_bridge/ngraph_mark_for_clustering.h" @@ -739,8 +741,7 @@ GetTFToNgOpMap() { // // Main entry point for the marking pass. // -Status MarkForClustering(Graph* graph, - const std::set skip_these_nodes) { +Status MarkForClustering(Graph* graph, std::set skip_these_nodes) { const TypeConstraintMap& type_constraint_map = GetTypeConstraintMap(); // confirmation_function_map is non-const unlike the other maps @@ -800,9 +801,50 @@ Status MarkForClustering(Graph* graph, vector nodes_marked_for_clustering; shared_ptr op_backend = BackendManager::GetBackend(); +#if !defined NGRAPH_TF_USE_GRAPPLER_OPTIMIZER + std::set disabled_nodes = {}; + // Find a list of nodes that are of the types that are disabled + for (auto itr : graph->nodes()) { + if (disabled_ops_set.find(itr->type_string()) != disabled_ops_set.end()) { + disabled_nodes.insert(itr->name()); + } + } + // const BuildGraphOptions options; + // cout << "trying " << options.callable_options.fetch().size() << endl; + // cout << "trying " << options.callable_options.tensor_connection().size() + // < fetch_nodes; + for (auto edge : graph->edges()) { + Node* src = edge->src(); + Node* dst = edge->dst(); + // Skip source/sink + if (dst->IsSink()) { + cout << "Skip this node " << src->type_string() << endl; + fetch_nodes.insert(src->name()); + } + } + cout << "Total nodes " << graph->num_nodes() << endl; + cout << "OP nodes " << graph->num_op_nodes() << endl; + + // nodes_to_add_identity_to = fetch_nodes - disabled_nodes + std::set nodes_to_add_identity_to; + std::set_difference(fetch_nodes.begin(), fetch_nodes.end(), + disabled_nodes.begin(), disabled_nodes.end(), + std::inserter(nodes_to_add_identity_to, + nodes_to_add_identity_to.begin())); + + // Rewrite graph to add IdentityN node so the fetch node can be encapsulated + // as well + // If the fetch node in question has 0 outputs or any of the outputs + // has ref type as a data type then don't add IdentityN node, but the fetch + // node will be skipped from marking and clustering. + TF_RETURN_IF_ERROR(AddIdentityN(graph, nodes_to_add_identity_to)); + skip_these_nodes = nodes_to_add_identity_to; +#endif + for (auto node : graph->op_nodes()) { + cout << node->type_string() << endl; bool mark_for_clustering = false; - do { // check if output node bool skip_it = false; diff --git a/ngraph_bridge/ngraph_rewrite_pass.cc b/ngraph_bridge/ngraph_rewrite_pass.cc index 80867189d..e840d8dce 100644 --- a/ngraph_bridge/ngraph_rewrite_pass.cc +++ b/ngraph_bridge/ngraph_rewrite_pass.cc @@ -78,6 +78,7 @@ mutex NGraphRewritePass::s_serial_counter_mutex; class NGraphEncapsulationPass : public NGraphRewritePass { public: Status Run(const GraphOptimizationPassOptions& options) override { + // cout << "trying " << options.sessioncallable_options.fetch().size(); // If we don't get a main graph, log that fact and bail. if (options.graph == nullptr) { NGRAPH_VLOG(0) << "NGraphEncapsulationPass: options.graph == nullptr"; @@ -151,6 +152,6 @@ class NGraphEncapsulationPass : public NGraphRewritePass { } // namespace ngraph_bridge -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, ngraph_bridge::NGraphEncapsulationPass); } // namespace tensorflow \ No newline at end of file diff --git a/test/tests_linux_cpu.txt b/test/tests_linux_cpu.txt index 00769c15b..997cedca4 100644 --- a/test/tests_linux_cpu.txt +++ b/test/tests_linux_cpu.txt @@ -11,12 +11,12 @@ # Read in one/more external manifest file(s) # Path specified is relative to this file's path -tests_common.txt +#tests_common.txt ################################################### [RUN] # Specify tests/patterns/regex that should be included - +MathOps.Abs1D ################################################### [SKIP] # Specify tests/patterns/regex that should be excluded/skipped