msg_gan.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import progressbar as pb
import os
import itertools
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers
import gan_layers as gl
import gan_util as util
import gan_params as gp
gpus = tf.config.list_physical_devices('GPU')
if gpus:  # guard against CPU-only machines
    tf.config.experimental.set_memory_growth(gpus[0], gp.gpu_grow_memory)
tf.config.experimental_run_functions_eagerly(gp.run_functions_eagerly)
plt.rcParams["figure.figsize"] = gp.figure_size


class MsgGan:
    def __init__(self):
        assert np.ceil(np.log2(gp.target_res)) == \
            np.floor(np.log2(gp.target_res)), 'Target resolution must be a power of 2'
        self.target_res = int(np.log2(gp.target_res / 2))
        self.gen_opt = optimizers.Adam(learning_rate=gp.learning_rate, beta_1=gp.beta_1, beta_2=gp.beta_2)
        self.crt_opt = optimizers.Adam(learning_rate=gp.learning_rate, beta_1=gp.beta_1, beta_2=gp.beta_2)
        if gp.use_mixed_precision:
            # Wrapping the optimizers also enables dynamic loss scaling,
            # which train_generator/train_critic rely on below
            self.gen_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(self.gen_opt)
            self.crt_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(self.crt_opt)
        # Build models
        self.generator = self.build_generator()
        self.critic = self.build_critic()
        # Fixed latent seed for keeping track of progress
        self.seed = util.generate_latents()
        # Set up directories
        util.init_directory(gp.tensorboard_dir)  # log dir
        util.init_directory(gp.sample_output_dir)  # generated sample dir
        util.init_directory(gp.model_dir)  # model architecture plot dir
        util.init_directory(gp.model_weight_dir)  # saved model weights dir
        # Set up TensorBoard logging
        self.log_writer = tf.summary.create_file_writer(gp.tensorboard_dir)
        self.total_seen = 0  # keep track of total number of images seen
        self.total_batches = 0  # keep track of total batches seen
        # Write model summaries to TensorBoard for debugging
        self.write_model_summaries()
        # Create CheckpointManager to save model state during training
        self.checkpoint_mgr = self.create_checkpoint_manager()
        # Progress bar format
        self.pb_widgets = [
            '[', pb.Variable('batches', format='{formatted_value}'), '] ', pb.Variable('g_loss', precision=10), ', ',
            pb.Variable('c_loss', precision=10), ' [', pb.SimpleProgress(), '] ',
            pb.FileTransferSpeed(unit='img', prefixes=['', 'k', 'm']), '\t', pb.ETA()
        ]

    def write_model_summaries(self):
        with self.log_writer.as_default():
            tf.summary.text('generator', self.generator.to_json(), step=0)
            tf.summary.text('critic', self.critic.to_json(), step=0)

    def create_checkpoint_manager(self):
        # Note: tf.train.Checkpoint only tracks checkpointable objects, so the
        # plain-int counters and the latent seed may need to be wrapped in
        # tf.Variable for their values to actually be saved and restored
        checkpoint = tf.train.Checkpoint(generator=self.generator, critic=self.critic,
                                         gen_opt=self.gen_opt, crt_opt=self.crt_opt,
                                         seed=self.seed, total_seen=self.total_seen,
                                         total_batches=self.total_batches)
        return tf.train.CheckpointManager(checkpoint, directory=gp.model_weight_dir, max_to_keep=3)
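
    # A hedged sketch, not part of the original file: resuming from the most
    # recent checkpoint written by the manager above. Assumes training is
    # restarted with the same gan_params configuration.
    def restore_latest(self):
        latest = self.checkpoint_mgr.latest_checkpoint
        if latest is not None:
            self.checkpoint_mgr.checkpoint.restore(latest)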

    def random_image(self, show=True):
        r_img = self.generator(util.generate_latents())
        self.view_imgs(r_img, show=show)

    # View multi-scale generator output as a grid: one column per scale
    def view_imgs(self, images, show=False, rows=4):
        assert rows <= gp.batch_size, 'Number of rows cannot exceed batch size'
        fig, axs = plt.subplots(nrows=rows, ncols=self.target_res)
        # images is indexed [scale][batch], so pair every scale with each grid row
        image_indices = itertools.product(range(self.target_res), range(rows))
        for scale, row in image_indices:
            image = images[scale][row]
            image = (image - np.min(image)) / np.ptp(image)  # scale images to [0, 1]
            axs[row, scale].axis('off')
            axs[row, scale].imshow(image)
        if show:
            plt.show()
        else:
            plt.savefig(os.path.join(gp.sample_output_dir, str(self.total_seen) + '.png'))
        # Clean up to avoid memory leaks
        plt.cla()  # clear current axes
        plt.close(fig)

    # TODO: Implement exponential moving averages for the generator weights
    #       (a hedged sketch follows below)
    # TODO: Implement multi-GPU support
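
    # Hedged sketch for the EMA TODO above, not the author's implementation:
    # keep a shadow copy of the generator weights (e.g. initialized from
    # self.generator.get_weights()) and blend it toward the live weights
    # after every generator step. The decay value 0.999 is an assumption.
    def update_ema(self, ema_weights, beta=0.999):
        for shadow, live in zip(ema_weights, self.generator.get_weights()):
            # In-place update: shadow = beta * shadow + (1 - beta) * live
            shadow *= beta
            shadow += (1 - beta) * live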

    def train(self):
        epochs = 10
        dataset = util.load_celeba(self.target_res)
        for epoch in range(epochs):
            epoch_seen = 0
            with pb.ProgressBar(widgets=self.pb_widgets, max_value=gp.images_per_epoch) as progress:
                while epoch_seen < gp.images_per_epoch:
                    # Get a new batch of real images and create new generator input
                    real_batch = dataset.next()
                    latent_input = util.generate_latents()
                    critic_loss = self.train_critic(real_batch, latent_input)
                    generator_loss = self.train_generator(real_batch, latent_input)
                    epoch_seen += gp.batch_size
                    self.total_seen += gp.batch_size
                    self.total_batches += 1
                    # Write results to TensorBoard
                    with self.log_writer.as_default():
                        tf.summary.scalar('g_loss', generator_loss, step=self.total_seen)
                        tf.summary.scalar('c_loss', critic_loss, step=self.total_seen)
                    # TODO: Implement with callbacks
                    if util.time_to_update(epoch_seen):
                        progress.update(epoch_seen, batches=self.total_batches,
                                        g_loss=generator_loss, c_loss=critic_loss)
                    if util.time_for_img(epoch_seen):
                        sample_imgs = self.generator(self.seed)
                        self.view_imgs(sample_imgs)
            # Save model weights after each epoch
            self.checkpoint_mgr.save()

    # Hypothesis: taking the gradients of the MEAN hinge loss loses per-sample
    # batch information and destabilizes the training process
    # Relativistic average hinge loss (RaHinge), combining the hinge loss used
    # in SAGAN (https://arxiv.org/pdf/1805.08318.pdf) with the relativistic
    # average discriminator (https://arxiv.org/abs/1807.00734)
    @tf.function
    def train_generator(self, real_batch, latent_input):
        with tf.GradientTape() as gen_tape:
            fake_batch = self.generator(latent_input)
            real_preds = self.critic(real_batch)
            fake_preds = self.critic(fake_batch)
            real_fake_diff = real_preds - tf.reduce_mean(fake_preds)
            fake_real_diff = fake_preds - tf.reduce_mean(real_preds)
            # Generator wants reals rated below the average fake and vice versa
            loss = tf.reduce_mean(tf.nn.relu(1 + real_fake_diff)) + tf.reduce_mean(tf.nn.relu(1 - fake_real_diff))
            # get_scaled_loss/get_unscaled_gradients exist only on the
            # loss-scaling optimizer, i.e. when gp.use_mixed_precision is True
            scaled_loss = self.gen_opt.get_scaled_loss(loss)
        scaled_gen_grads = gen_tape.gradient(scaled_loss, self.generator.trainable_variables)
        gen_grads = self.gen_opt.get_unscaled_gradients(scaled_gen_grads)
        self.gen_opt.apply_gradients(zip(gen_grads, self.generator.trainable_variables))
        return loss

    @tf.function
    def train_critic(self, real_batch, latent_input):
        with tf.GradientTape() as crt_tape:
            fake_batch = self.generator(latent_input)
            real_preds = self.critic(real_batch)
            fake_preds = self.critic(fake_batch)
            real_fake_diff = real_preds - tf.reduce_mean(fake_preds)
            fake_real_diff = fake_preds - tf.reduce_mean(real_preds)
            # Mirror of the generator loss: reals above the average fake, fakes below
            loss = tf.reduce_mean(tf.nn.relu(1 - real_fake_diff)) + tf.reduce_mean(tf.nn.relu(1 + fake_real_diff))
            scaled_loss = self.crt_opt.get_scaled_loss(loss)
        scaled_crt_grads = crt_tape.gradient(scaled_loss, self.critic.trainable_variables)
        crt_grads = self.crt_opt.get_unscaled_gradients(scaled_crt_grads)
        self.crt_opt.apply_gradients(zip(crt_grads, self.critic.trainable_variables))
        return loss

    # TODO: Implement fused upscale & downscale
    def build_generator(self):
        # Keep track of multi-scale generator outputs to feed to critic
        outputs = []
        input_layer = gl.input_layer(shape=(gp.latent_dim,))
        # Input block
        gen = gl.dense(input_layer, 4 * 4 * util.nf(0))
        gen = gl.reshape(gen, shape=(4, 4, util.nf(0)))
        gen = gl.leaky_relu(gen)
        gen = gl.normalize(gen, method='pixel_norm')
        gen = gl.conv2d(gen, util.nf(0), kernel=3)
        gen = gl.leaky_relu(gen)
        gen = gl.normalize(gen, method='pixel_norm')
        outputs.append(gl.to_rgb(gen))
        # Add the hidden generator blocks, each doubling the resolution
        for block_res in range(self.target_res - 1):
            gen = gl.nearest_neighbor(gen)
            gen = gl.conv2d(gen, util.nf(block_res + 1), kernel=3)
            gen = gl.leaky_relu(gen)
            gen = gl.normalize(gen, method='pixel_norm')
            gen = gl.conv2d(gen, util.nf(block_res + 1), kernel=3)
            gen = gl.leaky_relu(gen)
            gen = gl.normalize(gen, method='pixel_norm')
            outputs.append(gl.to_rgb(gen))
        # Return finalized model TODO: Compile
        outputs = list(reversed(outputs))  # so that generator outputs and critic inputs are aligned
        return Model(inputs=input_layer, outputs=outputs)
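
    # Illustrative sanity check, not in the original file: with a 4x4 base
    # block and target_res - 1 doubling blocks, the i-th (reversed) output
    # should be square with side 2 ** (target_res + 1 - i), e.g.
    # [256, 128, 64, 32, 16, 8, 4] when gp.target_res == 256.
    def check_output_shapes(self):
        for i, shape in enumerate(self.generator.output_shape):
            expected = 2 ** (self.target_res + 1 - i)
            assert shape[1] == shape[2] == expected, (shape, expected)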

    # TODO: This can be cleaned up a bit
    def build_critic(self):
        inputs = []
        exp_res = util.log_to_res(self.target_res)
        # Input layer (no concatenate in the input layer)
        inputs.append(gl.input_layer(shape=(exp_res, exp_res, 3)))
        crt = gl.conv2d(inputs[-1], util.nf(self.target_res - 1), kernel=1)
        if gp.mbstd_in_each_layer:
            crt = gl.minibatch_std(crt)
        crt = gl.conv2d(crt, util.nf(self.target_res - 1), kernel=3)
        crt = gl.leaky_relu(crt)
        crt = gl.conv2d(crt, util.nf(self.target_res - 2), kernel=3)
        crt = gl.leaky_relu(crt)
        crt = gl.avg_pool(crt)
        # Intermediate layers
        for res in range(self.target_res - 1, 1, -1):
            exp_res = util.log_to_res(res)
            inputs.append(gl.input_layer(shape=(exp_res, exp_res, 3)))
            # Multi-scale critic input
            crt = gl.combine(crt, inputs[-1], features=util.nf(res - 1))
            if gp.mbstd_in_each_layer:
                crt = gl.minibatch_std(crt)
            crt = gl.conv2d(crt, util.nf(res - 1), kernel=3)
            crt = gl.leaky_relu(crt)
            crt = gl.conv2d(crt, util.nf(res - 2), kernel=3)
            crt = gl.leaky_relu(crt)
            crt = gl.avg_pool(crt)
        # Output layer
        inputs.append(gl.input_layer(shape=(4, 4, 3)))
        crt = gl.combine(crt, inputs[-1], features=util.nf(0))
        crt = gl.minibatch_std(crt)
        crt = gl.conv2d(crt, util.nf(0), kernel=3)
        crt = gl.leaky_relu(crt)
        crt = gl.flatten(crt)
        crt = gl.dense(crt, util.nf(0))
        crt = gl.leaky_relu(crt)
        crt = gl.dense(crt, 1, dtype='float32')
        # Finalized model
        return Model(inputs=inputs, outputs=crt)


msg = MsgGan()
# Save architecture plots of the generator and critic
util.pm(msg.generator, 'g' + str(msg.generator.output_shape[0][1]))
util.pm(msg.critic, 'c' + str(msg.critic.input_shape[0][1]))
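
# A minimal usage sketch, not in the original file: run training and write a
# sample grid to gp.sample_output_dir (assumes gan_params is configured and
# the CelebA data used by util.load_celeba is available).
if __name__ == '__main__':
    msg.train()
    msg.random_image(show=False)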