-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlog_utils.py
86 lines (64 loc) · 2.85 KB
/
log_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import wandb
import numpy as np
import equinox as eqx
from jax.nn import sigmoid
from einops import rearrange
from jax.random import split, randint
from jax import vmap
from PIL import Image
from models import AutoEncoder, DoublingVNCA, NonDoublingVNCA
# typing
from jax import Array
from jax.random import PRNGKeyArray
def save_model(model, step):
    '''Serialise ``model`` to a local ``.eqx`` file and upload it to wandb.

    The file name combines the model's class name with ``step``
    (``<ClassName>_gstep<step>.eqx``). It is a relative path, so the file is
    written to the current working directory before being passed to
    ``wandb.save``.

    Args:
        model: Equinox module whose leaves are serialised.
        step: Training step embedded in the file name.
    '''
    class_name = model.__class__.__name__
    checkpoint_path = f'{class_name}_gstep{step}.eqx'
    eqx.tree_serialise_leaves(checkpoint_path, model)
    wandb.save(checkpoint_path)
def restore_model(model_like, file_name, run_path=None):
    '''Download a serialised model from wandb and load it into ``model_like``.

    Args:
        model_like: Equinox module with the same tree structure as the saved
            model; it supplies the structure that the stored leaves fill.
        file_name: Name of the ``.eqx`` file stored in the wandb run.
        run_path: Optional wandb run path, passed straight to
            ``wandb.restore``.

    Returns:
        The deserialised model.

    Raises:
        FileNotFoundError: If wandb could not locate ``file_name``.
    '''
    restored = wandb.restore(file_name, run_path=run_path)
    if restored is None:
        raise FileNotFoundError(
            f'Could not restore {file_name!r} from wandb (run_path={run_path!r})'
        )
    # wandb.restore hands back an open file object. Close it right away and
    # keep only its path: wandb may place the download in the run directory
    # rather than the current working directory, so ``file_name`` alone is not
    # a reliable location to read from.
    with restored:
        local_path = restored.name
    return eqx.tree_deserialise_leaves(local_path, model_like)
def to_wandb_img(x: Array) -> wandb.Image:
    '''Wrap a (c, h, w) array as a ``wandb.Image``.

    The array is first converted with ``to_PIL_img``; see that function for
    how values and channels are handled.
    '''
    pil_image = to_PIL_img(x)
    return wandb.Image(pil_image)
def to_PIL_img(x: Array) -> Image.Image:
    '''Convert a (c, h, w) array with values in [0, 1] to a PIL Image.

    Values are clipped to [0, 1] and quantised to 8 bits by rounding.
    A 3- or 4-channel input becomes a colour (RGB / RGBA) image. For any other
    channel count only the first channel is used, giving a grayscale image.

    Args:
        x: Array of shape (c, h, w).

    Returns:
        A ``PIL.Image.Image``.
    '''
    pixels = np.clip(np.asarray(x), 0, 1)
    # Round instead of truncating so that e.g. 0.999 maps to 255, not 254.
    pixels = np.rint(255 * pixels).astype(np.uint8)
    if pixels.shape[0] in (3, 4):
        # PIL expects colour images channels-last: (h, w, c).
        return Image.fromarray(np.ascontiguousarray(np.moveaxis(pixels, 0, -1)))
    return Image.fromarray(pixels[0])
@eqx.filter_jit
def to_grid(x: Array, ih: int, iw: int) -> Array:
    '''Tile a stack of images into one grid image.

    Args:
        x: Images of shape (ih * iw, c, h, w). Image ``i * iw + j`` lands in
            grid row ``i``, column ``j`` (row-major order).
        ih: Number of grid rows.
        iw: Number of grid columns.

    Returns:
        Array of shape (c, ih * h, iw * w).
    '''
    pattern = '(ih iw) c h w -> c (ih h) (iw w)'
    return rearrange(x, pattern, ih=ih, iw=iw)
@eqx.filter_jit
def log_center(model: AutoEncoder) -> Array:
    '''Return ``model.center()`` passed through a sigmoid.

    Intended as "the center of the latent space". Presumably ``center()``
    produces decoder logits for the latent center — TODO confirm against
    ``AutoEncoder``.
    '''
    center_logits = model.center()
    return sigmoid(center_logits)
@eqx.filter_jit
def log_samples(
    model: AutoEncoder,
    ih: int = 4,
    iw: int = 8,
    *,
    key: PRNGKeyArray,
) -> Array:
    '''Draw ``ih * iw`` samples from ``model`` and tile them into a grid.

    Args:
        model: Model exposing ``sample(key=...)``.
        ih: Number of grid rows.
        iw: Number of grid columns.
        key: PRNG key; split into one subkey per sample.

    Returns:
        Sigmoid-activated samples arranged by ``to_grid``.
    '''
    sample_keys = split(key, ih * iw)
    # jax.vmap maps keyword arguments over their leading axis, so each call
    # to ``model.sample`` receives its own subkey.
    batched_sample = vmap(model.sample)
    activated = sigmoid(batched_sample(key=sample_keys))
    return to_grid(activated, ih=ih, iw=iw)
@eqx.filter_jit
def log_reconstructions(
    model: AutoEncoder,
    data: Array,
    ih: int = 4,
    iw: int = 8,
    *,
    key: PRNGKeyArray,
) -> Array:
    '''Reconstruct ``ih * iw`` randomly chosen inputs and tile them into a grid.

    Args:
        model: Model whose call ``model(x, key=...)`` returns a 4-tuple; only
            the first element (the reconstructions) is used.
        data: Dataset indexed along its leading axis.
        ih: Number of grid rows.
        iw: Number of grid columns.
        key: PRNG key. It is split once so that the index draw and the
            model's per-example keys come from independent streams.

    Returns:
        Sigmoid-activated reconstructions arranged by ``to_grid``.
    '''
    n = ih * iw
    # JAX keys must never be reused: derive separate subkeys for choosing
    # examples and for the model's own randomness.
    idx_key, model_key = split(key)
    idx = randint(idx_key, (n,), 0, len(data))  # sampled with replacement
    x = data[idx]
    reconstructions, _, _, _ = vmap(model)(x, key=split(model_key, n))
    # Flatten the vmapped batch axis n with the per-example axis m.
    # NOTE(review): to_grid needs n * m == ih * iw, so this assumes the model
    # returns a single reconstruction per input (m == 1) — TODO confirm.
    reconstructions = rearrange(reconstructions, 'n m c h w -> (n m) c h w')
    return to_grid(sigmoid(reconstructions), ih=ih, iw=iw)
@eqx.filter_jit
def log_growth_stages(model: DoublingVNCA, *, key: PRNGKeyArray) -> Array:
    '''Tile the growth stages produced by a DoublingVNCA into a grid.

    The grid has ``model.K`` rows and ``model.N_nca_steps + 1`` columns, so
    ``model.growth_stages`` must yield ``K * (N_nca_steps + 1)`` images for
    ``to_grid`` to succeed.
    '''
    stages = model.growth_stages(key=key)
    n_rows = model.K
    n_cols = model.N_nca_steps + 1
    return to_grid(stages, ih=n_rows, iw=n_cols)
@eqx.filter_jit
def log_nca_stages(
    model: NonDoublingVNCA,
    ih: int = 4,
    iw: int = 9,
    *,
    key: PRNGKeyArray,
) -> Array:
    '''Tile the NCA stages of a NonDoublingVNCA into an ``ih`` x ``iw`` grid.

    Args:
        model: Model exposing ``nca_stages(T=..., key=...)``.
        ih: Number of grid rows.
        iw: Number of grid columns.
        key: PRNG key forwarded to ``model.nca_stages``.

    Returns:
        The stages arranged by ``to_grid``. ``T`` is set to ``ih * iw``;
        presumably that yields exactly one image per grid cell — TODO confirm.
    '''
    total_cells = ih * iw
    return to_grid(model.nca_stages(T=total_cells, key=key), ih=ih, iw=iw)