# easynn.py
class Expr:
    # A node in the expression graph: an operator applied to input expressions.
    next_id = 0

    def __init__(self, op, inputs):
        self.op = op
        self.inputs = inputs
        self.id = Expr.next_id
        Expr.next_id += 1
        if not isinstance(op, Op):
            raise Exception("Not an operator: %s" % op)

    def __dfs_post(self, ids, visitor):
        # Post-order depth-first traversal; ids tracks visited node ids.
        ids[self.id] = True
        for expr in self.inputs:
            if expr.id in ids:
                continue
            expr.__dfs_post(ids, visitor)
        visitor(self)

    def statements(self):
        # Render every reachable node once, in post-order, one statement per line.
        lines = []
        self.__dfs_post({}, lambda that: lines.append("%s" % that))
        return "\n".join(lines)

    def __str__(self):
        args = ",".join(["t%d" % expr.id for expr in self.inputs])
        return "t%d = %s(%s)" % (self.id, self.op, args)

    def __promote(r):
        # Wrap a plain number or ndarray as a Const expression.
        if isinstance(r, Expr):
            return r
        else:
            return Const(r)

    def __add__(self, r):
        return Op("", "Add", 2, {})(self, Expr.__promote(r))

    def __sub__(self, r):
        return Op("", "Sub", 2, {})(self, Expr.__promote(r))

    def __mul__(self, r):
        return Op("", "Mul", 2, {})(self, Expr.__promote(r))

    def __neg__(self):
        return Op("", "Neg", 1, {})(self)

    def compile(self, builder):
        # Feed nodes to the builder in post-order, then build the result.
        self.__dfs_post({}, lambda that: builder.append(that))
        return builder.build()

    def resolve(self, parameters):
        # Bind named parameters to every operator in the graph.
        self.__dfs_post({}, lambda that: that.op.resolve(parameters))
        return self
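
# Illustrative note (assumption, not part of the original file): the overloaded
# operators above build new graph nodes, and plain numbers are promoted to
# Const automatically, e.g.
#   a = Input("a"); b = Input("b")
#   c = a * b + 2.0          # builds Mul, Const and Add nodes
#   print(c.statements())    # one "tN = op(args)" line per node, in post-order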
class Op:
    # An operator descriptor: name, type, arity, and its parameter dict.
    def __init__(self, name, op_type, num_args, parameters):
        self.name = name
        self.op_type = op_type
        self.num_args = num_args
        self.parameters = parameters

    def __call__(self, *inputs):
        # Apply this operator to input expressions, producing a new Expr node.
        if self.num_args >= 0 and self.num_args != len(inputs):
            raise Exception("%s: need %d arguments but found %d" % (self, self.num_args, len(inputs)))
        for i, expr in enumerate(inputs):
            if not isinstance(expr, Expr):
                raise Exception("%s: arg %d is not an expression: %s" % (self, i, expr))
        return Expr(self, inputs)

    def __str__(self):
        name = "%s.%s" % (self.name, self.op_type)
        if len(self.parameters) == 0:
            return name
        params = ",".join(["%s=%s" % (k, v.shape if hasattr(v, "shape") else v) for k, v in self.parameters.items()])
        return "%s[%s]" % (name, params)

    def resolve(self, parameters):
        # Copy in parameters whose keys are prefixed with "<name>.", e.g. "c1.weight".
        if self.name == "":
            return
        for k, v in parameters.items():
            if k.startswith(self.name + "."):
                self.parameters[k[len(self.name) + 1:]] = v
def Const(c):
    # A constant scalar or ndarray value.
    if isinstance(c, (int, float)):
        c = float(c)
    elif hasattr(c, "shape"):
        c = c.astype(float)
    else:
        raise Exception("Const must be float or int or ndarray: %s" % c)
    return Expr(Op("", "Const", 0, {
        "value": c
    }), [])

def Input(n):
    # A named input placeholder.
    return Expr(Op(n, "Input", 0, {}), [])

def Input2d(n, h, w, ic):
    # A named 2-D input placeholder (height x width x in_channels).
    return Expr(Op(n, "Input2d", 0, {
        "height": h,
        "width": w,
        "in_channels": ic
    }), [])

def MaxPool2d(k, s):
    return Op("", "MaxPool2d", 1, {
        "kernel_size": k,
        "stride": s
    })

def ReLU():
    return Op("", "ReLU", 1, {})

def Flatten():
    return Op("", "Flatten", 1, {})

def Conv2d(n, ic, oc, k, p=0):
    return Op(n, "Conv2d", 1, {
        "in_channels": ic,
        "out_channels": oc,
        "kernel_size": k,
        "padding": p
    })

def Linear(n, i, o):
    return Op(n, "Linear", 1, {
        "in_features": i,
        "out_features": o
    })

def Show():
    return Op("", "Show", 1, {})