forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform.h
174 lines (155 loc) · 5.61 KB
/
transform.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/graph.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
/**
* The Transform Base Object
*
* A Transform is an operation which manipulates a Caffe2 NetDef.
* You can consider it as a function: Transform.ApplyTo(NetDef) -> NetDef
*
* A Transform Operation does 4 things:
* 1) Creates a Graph object from a NetDef, which stores connections.
* 2) Pattern Matches on the Graph, to find subgraphs it wants to change.
* 3) Replaces the subgraphs that it's matched with new operators.
* 4) Creates a NetDef from the changed Graph, and returns it.
*
* The effect of a Transform is defined by its 3 protected virtual functions.
* 1) PatternRule determines for an ordered subgraph and a node, whether to
* consider adding the node to the subgraph.
* 2) ValidatorRule determines, for an ordered subgraph, whether it is a
* match.
* 3) ReplaceRule mutates the graph, based on a matched subgraph.
*
* This is the base class for all derived classes to base off. To create your
* own transform, write your implementations for PatternRule, ValidatorRule, and
* ReplaceRule.
*/
class TORCH_API Transform {
public:
Transform() {}
/**
* Apply a Transform onto a NetDef.
* Returns the transformed NetDef.
*/
NetDef ApplyTo(const NetDef& orig_net_def);
virtual ~Transform() {}
/**
* Determines the type of subgraphs that PatternMatch will find.
*
* CONNECTED_SUBGRAPH will only match subgraphs that are connected.
* These subgraphs satisfy that every node of the match is connected to the
* subgraph of the nodes that come before it.
* For example, in the graph (1) --> (2) --> (3) --> (4),
* This is capable of matching the subgraph [2, 3] and [4, 3]
* This is not capable of matching the subgraph [2, 4].
*
*
* SORTED_WRT_EXECUTION_ORDER will match subgraphs that guarantee
* sorted execution order.
* The nodes don't have to be connected. It is faster than General.
* For example, in the graph (1) --> (2) --> (3) --> (4),
* This is capable of matching the subgraph [2, 4], [3, 4].
* This is not capable of matching the subgraph [3, 1], [4, 3].
*
*
* GENERAL can match any subgraph.
* For example, in the graph (1) --> (2) --> (3) --> (4),
* This is capable of matching subgraphs [2, 4], [3, 4], [4, 2, 1].
* There is no ordered subgraph of G that cannot be matched by this.
*/
enum PatternMatchType {
CONNECTED_SUBGRAPH,
SORTED_WRT_EXECUTION_ORDER,
GENERAL
};
/**
* Generates all matches (stored as ordered subgraphs) and returns them.
*
* A match is stored as vector<int>, which is a mapping to OperatorDefs
* in Graph. The order matters.
*/
std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph);
/**
* Applies the replace rule onto each of the matches found.
*/
void ReplacePattern(
const std::vector<std::vector<int>>& matches,
transform::Graph* graph);
protected:
/**
* The PatternRule essentially answers:
* Given the current subgraph (ordered), should we append the new node at idx?
*/
virtual bool PatternRule(
const transform::Graph& g,
const std::vector<int>& subgraph,
int /*idx*/) {
CAFFE_NOT_IMPLEMENTED;
}
/**
* The ValidatorRule essentially answers:
* Given a subgraph, can we accept it?
*/
virtual bool ValidatorRule(
const transform::Graph& g,
const std::vector<int>& subgraph) {
CAFFE_NOT_IMPLEMENTED;
}
/**
* The ReplaceRule actually mutates the graph, and applies the transformation
* upon the subgraph.
*/
virtual bool ReplaceRule(
const std::vector<int>& subgraph,
transform::Graph* g_ptr) {
CAFFE_NOT_IMPLEMENTED;
}
void SetPatternMatchType(PatternMatchType type) {
pattern_match_type_ = type;
}
private:
/**
* A helper function for PatternMatch, which keeps track of the best subgraph
* so far.
*/
void PatternMatchHelper(
const transform::Graph& graph,
const std::vector<bool>& matched,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
/**
* Attempts to append each neighbor to the end of the subgraph.
*/
void TryNeighbors(
const transform::Graph& graph,
const std::map<int, std::vector<string>>& neighbors,
const std::vector<bool>& matched,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
PatternMatchType pattern_match_type_ = CONNECTED_SUBGRAPH;
};
/// Instantiates the Transform registered under `key`.
/// `key` is expected to name an entry in the transform registry.
/// The returned unique_ptr transfers ownership of the new object to the caller.
TORCH_API std::unique_ptr<Transform> CreateTransform(std::string key);
// Declares TransformRegistry, the registry of Transform types that the
// factory functions below create instances from.
C10_DECLARE_REGISTRY(TransformRegistry, Transform);
// Registers a transform in TransformRegistry under `name`; the variadic
// arguments are forwarded unchanged to C10_REGISTER_CLASS (presumably the
// Transform subclass to construct — confirm against C10_REGISTER_CLASS).
#define REGISTER_TRANSFORM(name, ...) \
C10_REGISTER_CLASS(TransformRegistry, name, __VA_ARGS__)
/**
 * One-shot helper: builds the Transform registered under `key` and
 * immediately applies it to `netdef`, returning the resulting NetDef.
 */
TORCH_API NetDef
ApplyTransform(const std::string& key, const NetDef& netdef);
/**
 * Creates the Transform registered under `key`, applies it to `netdef`, and
 * returns the transformed net only if it runs faster than the original net
 * (otherwise, presumably, the original `netdef` is returned — confirm).
 *
 * Benchmarking procedure: `init_netdef` is run first, then each of the two
 * nets is run `warmup_runs` times. The average time over `main_runs` runs is
 * then measured, and the transformed net is kept only if it is faster by a
 * factor of `improvement_threshold`.
 * NOTE(review): the exact comparison (e.g. old_time / new_time > threshold)
 * lives in the implementation — confirm before relying on it.
 *
 * @param key                   registry key of the transform to apply.
 * @param netdef                the net to transform and benchmark.
 * @param init_netdef           net run once before benchmarking.
 * @param warmup_runs           untimed runs per net before measuring.
 * @param main_runs             timed runs per net used for the average.
 * @param improvement_threshold speed-up factor required to keep the result.
 * @return the transformed net if it wins, per the rule above.
 */
// Top-level const on by-value parameters is dropped here: it is not part of
// the function type, so this declaration still matches the definition.
TORCH_API NetDef ApplyTransformIfFaster(
    const string& key,
    const NetDef& netdef,
    const NetDef& init_netdef,
    int warmup_runs,
    int main_runs,
    double improvement_threshold);
} // namespace