2626#include < iostream>
2727#include " tensorflow/core/graph/graph.h"
2828
29+ #include " ngraph/ngraph.hpp"
30+
31+ #include " ngraph_bridge/ngraph_partial_shapes.h"
32+
2933namespace tensorflow {
3034
3135namespace ngraph_bridge {
@@ -35,10 +39,110 @@ typedef std::map<std::string, std::vector<int>> ShapeHintMap;
3539// the integer represent AOT level requested.
3640typedef std::pair<bool , std::set<ShapeHintMap>> AOTInfo;
3741
42+ // TODO: an optimization would be to separate the analysis and rewriting passes
43+ // cleanly, so that analysis pass is run in mark_for_clustering, and its
44+ // information is reused here instead of recalculating
45+ // To do that an Encapsulator object with AnalysisPass run can be created in
46+ // MarkForClustering, and that can be passed to EncapsulateClusters
47+
48+ // / Takes a TF graph where ngraph_cluster attributes has been marked in a
49+ // / preceeding pass (assign_clusters), then replaces TF subgraphs and inserts
50+ // / encapsulate ops in their place. Optionally can perform ahead of time
51+ // / compilation.
3852Status EncapsulateClusters (
3953 Graph* graph, int graph_id, FunctionDefLibrary* fdeflib,
40- std::unordered_map<std::string, std::string> device_config,
41- AOTInfo aot_info);
54+ const std::unordered_map<std::string, std::string>& device_config,
55+ const AOTInfo& aot_info);
56+
57+ // TODO Encapsulator is dependent on ClusterManager. They could be made
58+ // independent.
59+
60+ // A class to perform analysis (identify subgraphs)
61+ // and rewriting (create encapsulates and splice them in)
62+ // Order of calling: construction -> AnalysisPass -> RewritePass
63+ // |
64+ // v
65+ // NewClusterIds
66+ // Any other order of calling will generate errors
67+ // Cannot be copied/moved or reset
68+ class Encapsulator {
69+ public:
70+ Encapsulator (Graph* g);
71+ // Populate ClusterManager with the subgraphs for each potential encapsulate
72+ Status AnalysisPass ();
73+ // Perform the actual graph surgery
74+ Status RewritePass (
75+ FunctionDefLibrary* fdeflib, int graph_id,
76+ const std::unordered_map<std::string, std::string>& device_config);
77+ // Returns the newly created cluster ids after AnalysisPass is done
78+ // Needed because ClusterManager (CM) might have contained old stuff,
79+ // so it might not be possible to query the CM itself to get this
80+ Status GetNewClusterIDs (std::set<int >& result);
81+
82+ Encapsulator (const Encapsulator&) = delete ;
83+ Encapsulator (Encapsulator&&) = delete ;
84+ Encapsulator& operator =(const Encapsulator&) = delete ;
85+ Encapsulator& operator =(Encapsulator&&) = delete ;
86+
87+ private:
88+ Graph* graph;
89+ // boolean to indicate if analysis has been done
90+ // If not rewritepass should not be called
91+ bool analysis_done;
92+ // boolean to indicate that rewrite is done;
93+ bool rewrite_done;
94+ // A map from cluster indices to the expected device name for nodes
95+ // in that cluster.
96+ std::map<int , std::string> device_name_map;
97+
98+ // We *should* eventually have a way of monitoring the device and the backend
99+ // together
100+ std::map<int , std::string> backend_name_map;
101+
102+ // As we build the graph we will be tracking the.. TODO(amprocte): finish
103+ // this comment.
104+ std::map<std::tuple<int , int >, std::tuple<int , int >> output_remap_map;
105+ std::map<std::tuple<int , int , int >, int > input_remap_map;
106+ std::map<std::tuple<int , std::string, int >, string> input_rename_map;
107+
108+ // A map from cluster indices to a vector of input data types.
109+ std::map<int , std::vector<std::tuple<int , int , DataType>>> cluster_input_map;
110+ // A map from cluster indices to a vector of output data types.
111+ std::map<int , std::vector<DataType>> cluster_output_dt_map;
112+
113+ // A map from cluster indices to corresponding NGraphEncapsulate nodes.
114+ std::map<int , Node*> cluster_node_map;
115+
116+ std::set<int > cluster_indices_for_this_graph;
117+
118+ static void AddInput (NodeDef* dst, StringPiece src_name, int src_slot);
119+ };
120+
121+ // Translates TF subgraph to ng function then compiles it
122+ Status PerformAOTOnEncapsulates (Graph* graph, const AOTInfo& aot_info);
123+
124+ std::string HintAsString (ShapeHintMap single_hint);
125+
126+ // Given a node, partialshape info from TF (present in the .pb itself) and a
127+ // shape hint, combine all that information
128+ PartialShape CombineNodeInfoAndHint (Node* node,
129+ PartialShape partial_shape_from_node,
130+ const ShapeHintMap& single_hint);
131+
132+ // Given a TF graph, it scans it for inputs and finds what TF is saying about
133+ // their shapes (in the .pb itself)
134+ // Creates a map between input node names and PartialShape information we get
135+ // from the TF graph
136+ std::map<std::string, PartialShape> GetShapesFromTFInputnodes (
137+ Graph* graph, const string& input_node_type);
138+
139+ // Given an encapsulate node, and the input shapes,
140+ // performs TranslateGraph and returns an ng function and a signature
141+ Status PerformTranslation (Node* node,
142+ const std::map<std::string, std::vector<int >>&
143+ inputs_node_shapes_for_compilation,
144+ std::string& signature,
145+ std::shared_ptr<ngraph::Function>& ng_function);
42146
43147} // namespace ngraph_bridge
44148} // namespace tensorflow
0 commit comments