-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodels.py
340 lines (263 loc) · 11.9 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
from jax import numpy as np
from jax import lax, vmap, checkpoint
from jax.random import split, normal, bernoulli, randint
from jax.nn import elu, sigmoid
from einops import rearrange, repeat
from equinox import Module
from equinox.nn import Sequential, Conv2d, ConvTranspose2d, Linear, Lambda
from functools import partial
# typing
from jax import Array
from typing import Optional, Sequence, Tuple, Any
from jax.random import PRNGKeyArray
def sample_gaussian(mu: Array, logvar: Array, shape: Sequence[int], *, key: PRNGKeyArray) -> Array:
    '''Sample from N(mu, exp(logvar)) using the reparameterization trick.

    Standard-normal noise of the requested ``shape`` is scaled by the standard
    deviation ``exp(logvar / 2)`` and shifted by ``mu``; the sample is thus a
    deterministic, differentiable function of ``mu`` and ``logvar``.
    ``mu`` and ``logvar`` must broadcast against ``shape``.
    '''
    eps = normal(key=key, shape=shape)
    sigma = np.exp(logvar * 0.5)
    return mu + sigma * eps
def sample_bernoulli(logits: Array, shape: Sequence[int], *, key: PRNGKeyArray) -> Array:
    '''Draw Bernoulli samples whose success probability is ``sigmoid(logits)``.

    ``logits`` must broadcast to ``shape``; the result is the boolean array
    returned by ``jax.random.bernoulli``.
    '''
    return bernoulli(key=key, p=sigmoid(logits), shape=shape)
def crop(x: Array, shape: Tuple[int, int, int]) -> Array:
    '''Centre-crop a channels-first image to a given size.

    Args:
        x: 3-D array ``(C, H, W)``.
        shape: target ``(c, h, w)``; the first ``c`` channels are kept and the
            central ``h x w`` window is extracted. When ``H - h`` (or ``W - w``)
            is odd, the extra row (column) is dropped from the bottom (right).

    Returns:
        ``x[:c]`` restricted to the central ``h x w`` window.

    Raises:
        ValueError: if ``h > H`` or ``w > W`` (cropping cannot enlarge).
    '''
    c, h, w = shape
    in_h, in_w = x.shape[-2:]
    if h > in_h or w > in_w:
        raise ValueError(f'cannot crop a {in_h}x{in_w} image to {h}x{w}')
    # Offsets are (input size - target size) // 2 so the window is centred
    # inside x; the window then spans exactly h rows and w columns.
    top, left = (in_h - h) // 2, (in_w - w) // 2
    return x[:c, top:top + h, left:left + w]
def pad(x: Array, p: int, c: float) -> Array:
    '''Surround a 3-D image ``(channels, h, w)`` with a constant border.

    Both spatial axes receive ``p`` cells on each side; the channel axis is
    not padded. Every new cell holds the constant ``c`` (here ``c`` is the
    fill value, not a channel count).
    '''
    border = (p, p)
    widths = ((0, 0), border, border)
    return np.pad(x, widths, mode='constant', constant_values=c)
def flatten(x: Array) -> Array:
    '''Collapse a channels-first ``(c, h, w)`` array into one vector of length ``c*h*w``.'''
    pattern = 'c h w -> (c h w)'
    return rearrange(x, pattern)
def damage(x: Array, *, key: PRNGKeyArray) -> Array:
    '''Zero out a randomly placed ``H//2 x W//2`` square in every channel.

    Args:
        x: grid of shape ``(L, H, W)``.
        key: PRNG key that selects the square's top-left corner.

    Returns:
        A copy of ``x`` (same shape and dtype) in which rows ``[r, r + H//2)``
        and columns ``[s, s + W//2)`` are zero across all ``L`` channels. The
        corner ``(r, s)`` is uniform over every position where the square fits.
    '''
    l, h, w = x.shape
    h_half, w_half = h // 2, w // 2
    # randint's maxval is exclusive: "+ 1" makes the last fitting corner
    # (h - h_half, w - w_half) reachable, otherwise the bottom row and the
    # right column could never be damaged.
    row, col = randint(
        key=key,
        shape=(2,),
        minval=np.zeros(2, dtype=np.int32),
        maxval=np.array([h - h_half + 1, w - w_half + 1], dtype=np.int32),
    )
    # dynamic_update_slice requires the update to share x's dtype.
    update = np.zeros((l, h_half, w_half), dtype=x.dtype)
    return lax.dynamic_update_slice(x, update, (0, row, col))
def double(x: Array) -> Array:
    '''Upsample ``(c, h, w)`` to ``(c, 2h, 2w)`` by copying every cell into a 2x2 patch.'''
    upsampled = repeat(x, 'c h w -> c (h 2) (w 2)')
    return upsampled
# Function-wrapping layers: ``Lambda`` lets a plain function sit inside an
# equinox ``Sequential`` next to trainable layers. The same instance is
# reused in several places below (e.g. ``Elu`` after every convolution).
Flatten: Lambda = Lambda(flatten)  # (c, h, w) -> (c*h*w,)
Sigmoid: Lambda = Lambda(sigmoid)  # elementwise logistic sigmoid
Double: Lambda = Lambda(double)  # (c, h, w) -> (c, 2h, 2w), each cell copied into a 2x2 patch
Elu: Lambda = Lambda(elu)  # elementwise ELU activation
class Encoder(Sequential):
    '''Convolutional encoder producing the two parameter vectors of q(z | x).

    Five 5x5 convolutions (padding 2; stride 1 for the first, 2 for the rest),
    each followed by ELU, widen one input channel to 512 feature maps. The
    maps are flattened and a dense layer emits ``2 * latent_size`` values,
    reshaped to ``(2, latent_size)`` so the caller can unpack two rows
    (``AutoEncoder.__call__`` unpacks them as mean and log-variance).

    NOTE(review): ``in_features=2048`` equals 512 * 2 * 2, so the final
    feature map must be 2x2 -- e.g. a (1, 32, 32) input. Confirm before
    changing input sizes.
    '''

    def __init__(self, latent_size: int, *, key: PRNGKeyArray):
        layer_keys = split(key, 6)
        widths = (1, 32, 64, 128, 256, 512)
        strides = (1, 2, 2, 2, 2)
        layers = []
        for i, s in enumerate(strides):
            layers.append(
                Conv2d(widths[i], widths[i + 1], kernel_size=(5, 5), stride=(s, s), padding=(2, 2), key=layer_keys[i])
            )
            layers.append(Elu)
        split_params = partial(rearrange, pattern='(p l) -> p l', p=2, l=latent_size)
        layers.extend([
            Flatten,
            Linear(in_features=2048, out_features=2 * latent_size, key=layer_keys[5]),
            Lambda(split_params),
        ])
        super().__init__(layers)
class LinearDecoder(Linear):
    '''Dense layer mapping a ``latent_size`` vector to 2048 = 512 * 2 * 2 features.'''

    def __init__(self, latent_size: int, *, key: PRNGKeyArray):
        n_features = 512 * 2 * 2
        super().__init__(in_features=latent_size, out_features=n_features, key=key)
class ConvolutionalDecoder(Sequential):
    '''Transposed-convolution stack from 512 feature maps to one output channel.

    Four stride-2 ``ConvTranspose2d`` layers (5x5 kernels, padding 2, output
    padding 1), each followed by ELU, reduce channels 512 -> 256 -> 128 ->
    64 -> 32. A final stride-1 layer with padding 4 maps 32 -> 1 with no
    activation (presumably producing logits).

    NOTE(review): the resulting spatial size depends on equinox's
    transposed-convolution padding semantics; surrounding code appears to
    expect 2x2 -> 28x28. Confirm if hyperparameters change.
    '''

    def __init__(self, *, key: PRNGKeyArray):
        keys = split(key, 5)
        channel_pairs = ((512, 256), (256, 128), (128, 64), (64, 32))
        layers = []
        for layer_key, (n_in, n_out) in zip(keys[:4], channel_pairs):
            layers += [
                ConvTranspose2d(n_in, n_out, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), key=layer_key),
                Elu,
            ]
        layers.append(ConvTranspose2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(4, 4), key=keys[4]))
        super().__init__(layers)
class BaselineDecoder(Sequential):
    '''Non-NCA decoder: dense -> (512, 2, 2) grid -> transposed convs -> padded map.

    The output is padded by 2 cells on each side with ``-inf``. Presumably the
    map holds logits, so the border corresponds to probability
    ``sigmoid(-inf) == 0``. The padding takes the 28x28 map to 32x32.
    '''

    def __init__(self, latent_size: int, *, key: PRNGKeyArray):
        dense_key, conv_key = split(key, 2)
        # reshape from 2048 to 512x2x2
        to_grid = partial(rearrange, pattern='(c h w) -> c h w', h=2, w=2, c=512)
        # pad from 28x28 to 32x32
        pad_border = partial(pad, p=2, c=float('-inf'))
        super().__init__([
            LinearDecoder(latent_size, key=dense_key),
            Lambda(to_grid),
            ConvolutionalDecoder(key=conv_key),
            Lambda(pad_border),
        ])
class Residual(Sequential):
    '''Per-pixel two-layer MLP with an identity skip: ``x + conv1x1(elu(conv1x1(x)))``.'''

    def __init__(self, latent_size: int, *, key: PRNGKeyArray) -> None:
        first_key, second_key = split(key, 2)
        pointwise = partial(Conv2d, latent_size, latent_size, kernel_size=(1, 1), stride=(1, 1))
        super().__init__([pointwise(key=first_key), Elu, pointwise(key=second_key)])

    def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
        '''Add the output of the inner layer stack to its input.'''
        branch = super().__call__(x, key=key)
        return x + branch
class Conv2dZeroInit(Conv2d):
    '''``Conv2d`` whose weight (and bias, if present) start at exactly zero.

    Takes the same arguments as ``Conv2d``. In this file it is the last layer
    of the NCA update networks, so at initialisation the residual update
    ``z + step(z)`` leaves ``z`` unchanged.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = np.zeros_like(self.weight)
        # With use_bias=False the layer has no bias (None); nothing to zero.
        if self.bias is not None:
            self.bias = np.zeros_like(self.bias)
class NCAStep(Sequential):
    '''NCA update network used by ``DoublingVNCA``.

    A 3x3 convolution (padding 1, so spatial size is kept) gathers each cell's
    neighbourhood. Four ``Residual`` blocks follow, then a zero-initialised 1x1
    convolution, so the network outputs zeros at initialisation.
    '''

    def __init__(self, latent_size: int, *, key: PRNGKeyArray) -> None:
        perceive_key, *block_keys, out_key = split(key, 6)
        perceive = Conv2d(latent_size, latent_size, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), key=perceive_key)
        blocks = [Residual(latent_size=latent_size, key=block_key) for block_key in block_keys]
        output = Conv2dZeroInit(latent_size, latent_size, kernel_size=(1, 1), stride=(1, 1), key=out_key)
        super().__init__([perceive, *blocks, output])
class NCAStepSimple(Sequential):
    '''Lightweight NCA update: 3x3 conv -> ELU -> zero-initialised 1x1 conv.'''

    def __init__(self, latent_size: int, *, key: PRNGKeyArray) -> None:
        conv_key, out_key = split(key)
        perceive = Conv2d(latent_size, latent_size, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), key=conv_key)
        output = Conv2dZeroInit(latent_size, latent_size, kernel_size=(1, 1), stride=(1, 1), key=out_key)
        super().__init__([perceive, Elu, output])
class AutoEncoder(Module):
    '''Shared variational-autoencoder logic; subclasses provide ``encoder`` and ``decoder``.

    ``encoder(x)`` must return a pair ``(mean, logvar)``. ``decoder(z)`` must
    map a single latent vector of length ``latent_size`` to a channels-first
    array, which ``__call__`` then crops to the input's shape.
    '''
    latent_size: int

    def __call__(self, x: Array, *, key: PRNGKeyArray, M: int = 1) -> Tuple[Array, Array, Array, Array]:
        '''Encode ``x``, draw ``M`` latent samples and decode each of them.

        Returns ``(x_hat, z, mean, logvar)``: ``x_hat`` stacks the ``M``
        reconstructions cropped to ``x.shape``, ``z`` stacks the ``M``
        samples, and ``mean``/``logvar`` are the encoder outputs.
        '''
        mean, logvar = self.encoder(x)
        z = sample_gaussian(mean, logvar, (M,) + tuple(mean.shape), key=key)
        decode_each = vmap(self.decoder)
        crop_each = vmap(partial(crop, shape=x.shape))
        x_hat = crop_each(decode_each(z))
        return x_hat, z, mean, logvar

    def encoder(self, x: Array) -> Array:
        '''Placeholder; subclasses must supply an encoder.'''
        raise NotImplementedError

    def decoder(self, z: Array) -> Array:
        '''Placeholder; subclasses must supply a decoder.'''
        raise NotImplementedError

    def center(self) -> Array:
        '''Decode the all-zeros latent vector (the mean of the N(0, I) prior).'''
        return self.decoder(np.zeros(self.latent_size))

    def sample(self, *, key: PRNGKeyArray) -> Array:
        '''Decode a single draw from the standard-normal prior.'''
        zeros = np.zeros(self.latent_size)
        z = sample_gaussian(zeros, zeros, shape=(self.latent_size,), key=key)
        return self.decoder(z)
class BaselineVAE(AutoEncoder):
    '''Plain convolutional VAE: ``Encoder`` paired with ``BaselineDecoder``.'''
    encoder: Encoder
    decoder: Sequential
    latent_size: int

    def __init__(self, latent_size: int = 256, *, key: PRNGKeyArray) -> None:
        enc_key, dec_key = split(key, 2)
        self.encoder = Encoder(latent_size=latent_size, key=enc_key)
        self.decoder = BaselineDecoder(latent_size=latent_size, key=dec_key)
        self.latent_size = latent_size
class DoublingVNCA(AutoEncoder):
    '''Variational NCA that grows its grid from 1x1 by repeated doubling.

    Decoding runs ``K`` stages. Each stage doubles the spatial resolution and
    then applies ``N_nca_steps`` residual updates ``grid + step(grid)``.
    '''
    encoder: Encoder
    step: NCAStep
    double: Lambda
    latent_size: int
    K: int
    N_nca_steps: int

    def __init__(self, latent_size: int = 256, K: int = 5, N_nca_steps: int = 8, *, key: PRNGKeyArray) -> None:
        enc_key, step_key = split(key)
        self.latent_size = latent_size
        self.K = K
        self.N_nca_steps = N_nca_steps
        self.encoder = Encoder(latent_size=latent_size, key=enc_key)
        self.step = NCAStep(latent_size=latent_size, key=step_key)
        self.double = Double

    def decoder(self, z: Array) -> Array:
        '''Grow a 1x1 grid seeded with ``z`` into a ``(latent_size, 2**K, 2**K)`` grid.

        The full grid is returned; ``AutoEncoder.__call__`` crops it.
        '''
        grid = rearrange(z, 'c -> c 1 1')
        for _stage in range(self.K):
            grid = self.double(grid)
            for _update in range(self.N_nca_steps):
                grid = grid + self.step(grid)
        return grid

    def growth_stages(self, n_channels: int = 1, *, key: PRNGKeyArray) -> Array:
        '''Record snapshots of a prior sample while it is being decoded.

        A latent vector is drawn from N(0, I) and decoded as in ``decoder``.
        A snapshot is taken after every doubling and after every NCA update.
        Each snapshot is ``sigmoid`` of the first ``n_channels`` channels,
        zero-padded symmetrically to the final ``2**K`` resolution so that all
        snapshots share one shape.

        Returns an array of ``K * (1 + N_nca_steps)`` snapshots.
        '''
        final_size = 2 ** self.K

        def snapshot(grid: Array) -> Array:
            probs = sigmoid(grid[:n_channels])
            border = (final_size - probs.shape[1]) // 2
            return pad(probs, p=border, c=0.0)

        prior_mean = np.zeros(self.latent_size)
        prior_logvar = np.zeros(self.latent_size)
        z = sample_gaussian(prior_mean, prior_logvar, (self.latent_size,), key=key)
        grid = rearrange(z, 'c -> c 1 1')

        frames = []
        for _stage in range(self.K):
            grid = self.double(grid)
            frames.append(snapshot(grid))
            for _update in range(self.N_nca_steps):
                grid = grid + self.step(grid)
                frames.append(snapshot(grid))
        return np.array(frames)
class NonDoublingVNCA(AutoEncoder):
    '''Variational NCA that decodes on a fixed 32x32 grid (no doubling).

    The latent vector is broadcast to every cell of a ``(latent_size, 32, 32)``
    grid, then the residual update ``z + step(z)`` is applied a number of
    times: ``N_nca_steps`` in ``decoder``, a random count in
    ``decode_grid_random``.
    '''
    encoder: Encoder
    step: NCAStepSimple
    latent_size: int
    N_nca_steps: int
    N_nca_steps_min: int
    N_nca_steps_max: int

    def __init__(self, latent_size: int = 256, N_nca_steps: int = 36, N_nca_steps_min: int = 32, N_nca_steps_max: int = 64, *, key: PRNGKeyArray) -> None:
        key1, key2 = split(key)
        self.encoder = Encoder(latent_size=latent_size, key=key1)
        self.step = NCAStepSimple(latent_size=latent_size, key=key2)
        self.latent_size = latent_size
        self.N_nca_steps = N_nca_steps
        self.N_nca_steps_min = N_nca_steps_min
        self.N_nca_steps_max = N_nca_steps_max

    def decoder(self, z: Array) -> Array:
        '''Decode one latent vector with ``N_nca_steps`` NCA updates.'''
        # repeat the latent sample over the image dimensions
        z = repeat(z, 'c -> c h w', h=32, w=32)
        # decode the latent sample by applying the NCA steps
        return self.decode_grid(z, T=self.N_nca_steps)

    def decode_grid_random(self, z: Array, *, key: PRNGKeyArray) -> Array:
        '''Decode an already-broadcast grid with a random number of updates.

        ``T`` is drawn uniformly from ``[N_nca_steps_min, N_nca_steps_max)``;
        ``randint``'s upper bound is exclusive.
        '''
        T = randint(key, shape=(1,), minval=self.N_nca_steps_min, maxval=self.N_nca_steps_max)
        return self.decode_grid(z, T=T)

    def decode_grid(self, z: Array, T: Array) -> Array:
        '''Apply the first ``T`` residual NCA updates to the grid ``z``.

        ``lax.scan`` needs a static trip count, so the loop always runs
        ``N_nca_steps_max`` iterations and iterations ``t >= T`` are the
        identity. A ``T`` above ``N_nca_steps_max`` is therefore silently
        capped at ``N_nca_steps_max`` updates.
        '''
        def update(z: Array) -> Array:
            return z + self.step(z)

        def keep(z: Array) -> Array:
            return z

        # checkpoint: recompute each iteration in the backward pass instead of
        # storing every intermediate grid (trades compute for memory).
        @checkpoint
        def scan_fn(z: Array, active: Array) -> Tuple[Array, Any]:
            z = lax.cond(active, update, keep, z)
            return z, None

        active_mask = np.arange(self.N_nca_steps_max) < T
        z, _ = lax.scan(scan_fn, z, active_mask)
        return z

    def nca_stages(self, n_channels: int = 1, T: int = 36, damage_idx: Optional[set] = None, *, key: PRNGKeyArray) -> Array:
        '''Record the grid after each of ``T`` NCA updates of a prior sample.

        Args:
            n_channels: number of leading channels to record (after ``sigmoid``).
            T: number of NCA updates to run.
            damage_idx: 0-based update indices after which ``damage`` zeroes a
                random square of the grid. ``None`` (default) means no damage.
            key: PRNG key, split internally for the prior draw and each damage.

        Returns:
            Array stacking ``T`` snapshots, each ``sigmoid(z[:n_channels])`` of
            the 32x32 grid.
        '''
        # None default instead of a shared mutable ``set()`` default.
        damage_idx = frozenset() if damage_idx is None else damage_idx
        # Split before sampling: using ``key`` both for the prior draw and as the
        # parent of the damage keys would reuse it and correlate the streams.
        key, sample_key = split(key)
        mean = np.zeros(self.latent_size)
        logvar = np.zeros(self.latent_size)
        z = sample_gaussian(mean, logvar, (self.latent_size,), key=sample_key)
        z = repeat(z, 'c -> c h w', h=32, w=32)
        # Decode the latent sample and save the processed image channels
        stages_probs = []
        for i in range(T):
            z = z + self.step(z)
            if i in damage_idx:
                key, damage_key = split(key)
                z = damage(z, key=damage_key)
            stages_probs.append(sigmoid(z[:n_channels]))
        return np.array(stages_probs)