class HiCGAN():
    def __init__(self, log_dir: str,
-                 lambda_pixel: float = 1e-5,
-                 loss_type_pixel: str = "L2",
+                 lambda_pixel: float = 100,
+                 lambda_disc: float = 0.5,
+                 loss_type_pixel: str = "L1",
                  tv_weight: float = 1e-10,
-                 input_size: int = 256):
+                 input_size: int = 256,
+                 plot_frequency: int = 20,
+                 learning_rate: float = 2e-5,
+                 adam_beta_1: float = 0.5):
        super().__init__()

        self.OUTPUT_CHANNELS = 1
@@ -24,15 +28,15 @@ def __init__(self, log_dir: str,
        if input_size in [64, 128, 256]:
            self.INPUT_SIZE = input_size
        self.NR_FACTORS = 14
-        self.LAMBDA = lambda_pixel
+        self.lambda_pixel = lambda_pixel
+        self.lambda_disc = lambda_disc
        self.tv_loss_Weight = tv_weight
        self.loss_type_pixel = loss_type_pixel
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
-        self.generator_optimizer = tf.keras.optimizers.Adam(2e-5, beta_1=0.5)
-        self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-5, beta_1=0.5)
+        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=adam_beta_1, name="Adam_Generator")
+        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=adam_beta_1, name="Adam_Discriminator")

        self.generator_intro_model = self.oneD_twoD_conversion()
-        self.discriminator_intro_model = self.oneD_twoD_conversion()
        self.generator = self.Generator()
        self.discriminator = self.Discriminator()

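For reference, a hypothetical instantiation of the extended constructor; the `log_dir` value is made up and the remaining values simply restate the new defaults:

hicgan = HiCGAN(log_dir="logs/run01",   # output directory (illustrative)
                lambda_pixel=100,       # weight of the per-pixel loss term
                lambda_disc=0.5,        # weight of the adversarial term
                loss_type_pixel="L1",   # "L1" or "L2"
                tv_weight=1e-10,        # total-variation smoothing weight
                input_size=256,         # one of 64, 128, 256
                plot_frequency=20,      # plot/checkpoint every 20 epochs
                learning_rate=2e-5,
                adam_beta_1=0.5)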
@@ -46,8 +50,10 @@ def __init__(self, log_dir: str,
                                         discriminator=self.discriminator)

        self.progress_plot_name = os.path.join(self.log_dir, "lossOverEpochs.png")
+        self.progress_plot_frequency = plot_frequency
+        self.example_plot_frequency = 5

-    def oneD_twoD_conversion(self, nr_filters_list=[16, 16, 32, 32, 64], kernel_width_list=[4, 4, 4, 4, 4], nr_neurons_List=[5000, 4000, 3000]):
+    def oneD_twoD_conversion(self, nr_filters_list=[1024, 512, 512, 256, 256, 128, 128, 64], kernel_width_list=[4, 4, 4, 4, 4, 4, 4, 4], apply_dropout: bool = False):
        inputs = tf.keras.layers.Input(shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
        # add 1D convolutions
        x = inputs
@@ -57,17 +63,26 @@ def oneD_twoD_conversion(self, nr_filters_list=[16,16,32,32,64], kernel_width_li
            convParamDict["filters"] = nr_filters
            convParamDict["kernel_size"] = kernelWidth
            convParamDict["data_format"] = "channels_last"
+            convParamDict["kernel_regularizer"] = tf.keras.regularizers.l2(0.01)
            if kernelWidth > 1:
                convParamDict["padding"] = "same"
            x = Conv1D(**convParamDict)(x)
            x = BatchNormalization()(x)
-            x = tf.keras.layers.Activation("sigmoid")(x)
-        # make the shape of a 2D-image
-        x = Conv1D(filters=self.INPUT_SIZE, strides=3, kernel_size=4, data_format="channels_last", activation="sigmoid", padding="same", name="conv1D_final")(x)
-        y = tf.keras.layers.Permute((2, 1))(x)
-        diag = tf.keras.layers.Lambda(lambda z: -1*tf.linalg.band_part(z, 0, 0))(x)
-        x = tf.keras.layers.Add()([x, y, diag])
-
+            if apply_dropout:
+                x = Dropout(0.5)(x)
+            x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
+        # make the shape of a square matrix
+        x = Conv1D(filters=self.INPUT_SIZE,
+                   strides=3,
+                   kernel_size=4,
+                   data_format="channels_last",
+                   activation="sigmoid",
+                   padding="same", name="conv1D_final")(x)
+        # ensure the matrix is symmetric, i.e. x = transpose(x)
+        x_T = tf.keras.layers.Permute((2, 1))(x)  # this is the matrix transpose
+        x = tf.keras.layers.Add()([x, x_T])
+        x = tf.keras.layers.Lambda(lambda z: 0.5*z)(x)  # add transpose and divide by 2
+        # reshape the matrix into a 2D grayscale image
        x = tf.keras.layers.Reshape((self.INPUT_SIZE, self.INPUT_SIZE, self.INPUT_CHANNELS))(x)
        model = tf.keras.Model(inputs=inputs, outputs=x, name="crazy_intro_model")
        # model.build(input_shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
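The Permute/Add/Lambda triple added above is the usual symmetrization x <- (x + x^T)/2, which guarantees the predicted contact matrix equals its own transpose. A minimal standalone sketch of the same idea (the 256x256 shape is illustrative):

import tensorflow as tf

def symmetrize(x):
    # average a batched square matrix with its transpose: (x + x^T) / 2
    x_T = tf.keras.layers.Permute((2, 1))(x)
    x = tf.keras.layers.Add()([x, x_T])
    return tf.keras.layers.Lambda(lambda z: 0.5 * z)(x)

inp = tf.keras.layers.Input(shape=(256, 256))
model = tf.keras.Model(inp, symmetrize(inp))  # output equals its own transpose

The Generator and Discriminator changes below apply the same trio with Permute((2, 1, 3)), because those tensors carry a trailing channel axis.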
@@ -141,8 +156,7 @@ def Generator(self):
        last = tf.keras.layers.Conv2DTranspose(self.OUTPUT_CHANNELS, 4,
                                               strides=2,
                                               padding='same',
-                                               kernel_initializer=initializer,
-                                               activation='sigmoid')  # (bs, 256, 256, 3)
+                                               kernel_initializer=initializer)  # (bs, 256, 256, 1)

        x = inputs
        x = twoD_conversion(x)
@@ -161,8 +175,11 @@ def Generator(self):
            x = tf.keras.layers.Concatenate()([x, skip])

        x = last(x)
+        # enforce symmetry
        x_T = tf.keras.layers.Permute((2, 1, 3))(x)
        x = tf.keras.layers.Add()([x, x_T])
+        x = tf.keras.layers.Lambda(lambda z: 0.5*z)(x)
+        x = tf.keras.layers.Activation("sigmoid")(x)

        return tf.keras.Model(inputs=inputs, outputs=x)

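Note that this hunk moves the sigmoid off the final Conv2DTranspose and places it after the symmetrization, so the transpose is averaged in logit space before squashing. The two orderings are not equivalent, as a toy check shows (numbers illustrative):

import tensorflow as tf

a, b = tf.constant(2.0), tf.constant(-1.0)
print(float(tf.sigmoid((a + b) / 2)))              # ~0.622: symmetrize, then sigmoid
print(float((tf.sigmoid(a) + tf.sigmoid(b)) / 2))  # ~0.575: sigmoid, then symmetrize

Either ordering yields a symmetric matrix in [0, 1]; averaging logits is simply the variant chosen here.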
@@ -175,7 +192,7 @@ def generator_loss(self, disc_generated_output, gen_output, target):
        else:
            pixel_loss = tf.reduce_mean(tf.square(target - gen_output))
        tv_loss = tf.reduce_mean(tf.image.total_variation(gen_output))
-        total_gen_loss = pixel_loss + 1/self.LAMBDA * gan_loss + self.tv_loss_Weight * tv_loss
+        total_gen_loss = self.lambda_pixel * pixel_loss + self.lambda_disc * gan_loss + self.tv_loss_Weight * tv_loss
        return total_gen_loss, gan_loss, pixel_loss

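The rewrite replaces the old 1/LAMBDA weighting of the adversarial term with two explicit coefficients, so the objective reads lambda_pixel * L_pixel + lambda_disc * L_GAN + tv_weight * L_TV. A quick numeric sketch with made-up loss values and the new defaults:

import tensorflow as tf

lambda_pixel, lambda_disc, tv_weight = 100.0, 0.5, 1e-10
pixel_loss = tf.constant(0.02)    # e.g. mean |target - gen_output|
gan_loss = tf.constant(0.7)       # BCE of discriminator output vs. ones
tv_loss = tf.constant(3000.0)     # total variation of the generated image

total = lambda_pixel * pixel_loss + lambda_disc * gan_loss + tv_weight * tv_loss
print(float(total))  # 2.0 + 0.35 + 3e-7, i.e. about 2.35

With these defaults the pixel term dominates and the TV term acts only as a very mild smoother.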
@@ -184,7 +201,7 @@ def Discriminator(self):

        inp = tf.keras.layers.Input(shape=[3*self.INPUT_SIZE, self.NR_FACTORS], name='input_image')
        tar = tf.keras.layers.Input(shape=[self.INPUT_SIZE, self.INPUT_SIZE, self.OUTPUT_CHANNELS], name='target_image')
-        twoD_conversion = self.discriminator_intro_model
+        twoD_conversion = self.oneD_twoD_conversion()
        # x = Flatten()(inp)
        # x = Dense(units=self.INPUT_SIZE*(self.INPUT_SIZE+1)//2)(x)
        # x = tf.keras.layers.LeakyReLU()(x)
@@ -198,24 +215,46 @@ def Discriminator(self):
        d = twoD_conversion(inp)
        d = tf.keras.layers.Concatenate()([d, tar])
        if self.INPUT_SIZE > 64:
+            # downsample and symmetrize 1
            d = HiCGAN.downsample(64, 4, False)(d)  # (bs, inp.size/2, inp.size/2, 64)
+            d_T = tf.keras.layers.Permute((2, 1, 3))(d)
+            d = tf.keras.layers.Add()([d, d_T])
+            d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
+            # downsample and symmetrize 2
            d = HiCGAN.downsample(128, 4)(d)  # (bs, inp.size/4, inp.size/4, 128)
+            d_T = tf.keras.layers.Permute((2, 1, 3))(d)
+            d = tf.keras.layers.Add()([d, d_T])
+            d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
        else:
+            # downsample and symmetrize 3
            d = HiCGAN.downsample(256, 4)(d)
+            d_T = tf.keras.layers.Permute((2, 1, 3))(d)
+            d = tf.keras.layers.Add()([d, d_T])
+            d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
+        # downsample and symmetrize 4
        d = HiCGAN.downsample(256, 4)(d)  # (bs, inp.size/8, inp.size/8, 256)
+        d_T = tf.keras.layers.Permute((2, 1, 3))(d)
+        d = tf.keras.layers.Add()([d, d_T])
+        d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
        d = Conv2D(512, 4, strides=1, padding="same", kernel_initializer=initializer)(d)  # (bs, inp.size/8, inp.size/8, 512)
+        d_T = tf.keras.layers.Permute((2, 1, 3))(d)
+        d = tf.keras.layers.Add()([d, d_T])
+        d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
        d = BatchNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = Conv2D(1, 4, strides=1, padding="same",
                   kernel_initializer=initializer)(d)  # (bs, inp.size/8, inp.size/8, 1)
+        d_T = tf.keras.layers.Permute((2, 1, 3))(d)
+        d = tf.keras.layers.Add()([d, d_T])
+        d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
        d = tf.keras.layers.Activation("sigmoid")(d)
        return tf.keras.Model(inputs=[inp, tar], outputs=d)

    def discriminator_loss(self, disc_real_output, disc_generated_output):
        real_loss = self.loss_object(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = self.loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
        total_disc_loss = real_loss + generated_loss
-        return total_disc_loss
+        return total_disc_loss, real_loss, generated_loss


    @tf.function
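discriminator_loss is the standard BCE patch-GAN objective; the change merely returns the real and generated components so they can be logged separately. A self-contained sketch mirroring the class's loss_object with toy logits (note, as a side observation, that the discriminator above ends in a sigmoid, so from_logits=True treats its outputs as logits even though they are already probabilities):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(disc_real_output, disc_generated_output):
    # real patches should be classified as 1, generated patches as 0
    real_loss = bce(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_loss + generated_loss, real_loss, generated_loss

total, real, gen = discriminator_loss(tf.constant([[2.0]]), tf.constant([[-2.0]]))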
@@ -226,8 +265,8 @@ def train_step(self, input_image, target, epoch):
            disc_real_output = self.discriminator([input_image, target], training=True)
            disc_generated_output = self.discriminator([input_image, gen_output], training=True)

-            gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(disc_generated_output, gen_output, target)
-            disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
+            gen_total_loss, _, _ = self.generator_loss(disc_generated_output, gen_output, target)
+            disc_loss, disc_real_loss, disc_gen_loss = self.discriminator_loss(disc_real_output, disc_generated_output)

        generator_gradients = gen_tape.gradient(gen_total_loss,
                                                self.generator.trainable_variables)
@@ -239,7 +278,7 @@ def train_step(self, input_image, target, epoch):
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                         self.discriminator.trainable_variables))

-        return gen_total_loss, disc_loss
+        return gen_total_loss, disc_loss, disc_real_loss, disc_gen_loss

    @tf.function
    def validationStep(self, input_image, target, epoch):
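For orientation, train_step follows the usual pix2pix recipe: one generator forward pass, two gradient tapes, and independent Adam updates. A condensed, self-contained sketch of the pattern (function and argument names are placeholders, not the repository's API):

import tensorflow as tf

@tf.function
def gan_train_step(generator, discriminator, gen_opt, disc_opt,
                   input_image, target, gen_loss_fn, disc_loss_fn):
    # separate tapes so each network's gradients are computed independently
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        disc_real = discriminator([input_image, target], training=True)
        disc_fake = discriminator([input_image, gen_output], training=True)
        gen_loss = gen_loss_fn(disc_fake, gen_output, target)
        disc_loss = disc_loss_fn(disc_real, disc_fake)
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    return gen_loss, disc_loss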
@@ -248,8 +287,8 @@ def validationStep(self, input_image, target, epoch):
        disc_real_output = self.discriminator([input_image, target], training=True)
        disc_generated_output = self.discriminator([input_image, gen_output], training=True)

-        gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(disc_generated_output, gen_output, target)
-        disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
+        gen_total_loss, _, _ = self.generator_loss(disc_generated_output, gen_output, target)
+        disc_loss, _, _ = self.discriminator_loss(disc_real_output, disc_generated_output)

        return gen_total_loss, disc_loss

@@ -273,29 +312,37 @@ def fit(self, train_ds, epochs, test_ds, steps_per_epoch: int):
        gen_loss_train = []
        gen_loss_val = []
        disc_loss_train = []
+        disc_loss_real_train = []
+        disc_loss_gen_train = []
        disc_loss_val = []
        for epoch in range(epochs):
            # generate sample output
-            if epoch % 5 == 0:
+            if epoch % self.example_plot_frequency == 0:
                for example_input, example_target in test_ds.take(1):
                    self.generate_images(self.generator, example_input, example_target, epoch)
            # Train
            train_pbar = tqdm(train_ds.enumerate(), total=steps_per_epoch)
            train_pbar.set_description("Epoch {:05d}".format(epoch + 1))
            gen_loss_batches = []
            disc_loss_batches = []
+            disc_real_loss_batches = []
+            disc_gen_loss_batches = []
            for _, (input_image, target) in train_pbar:
-                gen_loss, disc_loss = self.train_step(input_image["factorData"], target["out_matrixData"], epoch)
+                gen_loss, disc_loss, disc_real_loss, disc_gen_loss = self.train_step(input_image["factorData"], target["out_matrixData"], epoch)
                gen_loss_batches.append(gen_loss)
                disc_loss_batches.append(disc_loss)
+                disc_real_loss_batches.append(disc_real_loss)
+                disc_gen_loss_batches.append(disc_gen_loss)
                if epoch == 0:
                    train_pbar.set_postfix({"loss": "{:.4f}".format(gen_loss)})
                else:
                    train_pbar.set_postfix({"train loss": "{:.4f}".format(gen_loss),
                                            "val loss": "{:.4f}".format(gen_loss_val[-1])})
            gen_loss_train.append(np.mean(gen_loss_batches))
            disc_loss_train.append(np.mean(disc_loss_batches))
-            del gen_loss_batches, disc_loss_batches
+            disc_loss_real_train.append(np.mean(disc_real_loss_batches))
+            disc_loss_gen_train.append(np.mean(disc_gen_loss_batches))
+            del gen_loss_batches, disc_loss_batches, disc_real_loss_batches, disc_gen_loss_batches
            # Validation
            gen_loss_batches = []
            disc_loss_batches = []
@@ -308,13 +355,15 @@ def fit(self, train_ds, epochs, test_ds, steps_per_epoch: int):
            del gen_loss_batches, disc_loss_batches, train_pbar

            # saving (checkpoint) the model every 20 epochs
-            if (epoch + 1) % 20 == 0:
+            if (epoch + 1) % self.progress_plot_frequency == 0:
                # self.checkpoint.save(file_prefix=self.checkpoint_prefix)
                # plot loss
-                utils.plotLoss(pLossValueLists=[gen_loss_train, gen_loss_val, disc_loss_train, disc_loss_val],
-                               pNameList=["gen.loss train", "gen.loss val", "disc.loss train", "disc.loss val"],
-                               pFilename=self.progress_plot_name,
-                               useLogscale=True)
+                utils.plotLoss(pGeneratorLossValueLists=[gen_loss_train, gen_loss_val],
+                               pDiscLossValueLists=[disc_loss_train, disc_loss_real_train, disc_loss_gen_train, disc_loss_val],
+                               pGeneratorLossNameList=["training", "validation"],
+                               pDiscLossNameList=["train total", "train real", "train gen.", "valid. total"],
+                               pFilename=self.progress_plot_name,
+                               useLogscaleList=[True, False])
                np.savez(os.path.join(self.log_dir, "lossValues_{:05d}.npz".format(epoch)),
                         genLossTrain=gen_loss_train,
                         genLossVal=gen_loss_val,
@@ -325,10 +374,12 @@ def fit(self, train_ds, epochs, test_ds, steps_per_epoch: int):


        self.checkpoint.save(file_prefix=self.checkpoint_prefix)
-        utils.plotLoss(pLossValueLists=[gen_loss_train, gen_loss_val, disc_loss_train, disc_loss_val],
-                       pNameList=["gen.loss train", "gen.loss val", "disc.loss train", "disc.loss val"],
-                       pFilename=self.progress_plot_name,
-                       useLogscale=True)
+        utils.plotLoss(pGeneratorLossValueLists=[gen_loss_train, gen_loss_val],
+                       pDiscLossValueLists=[disc_loss_train, disc_loss_real_train, disc_loss_gen_train, disc_loss_val],
+                       pGeneratorLossNameList=["training", "validation"],
+                       pDiscLossNameList=["train total", "train real", "train gen.", "valid. total"],
+                       pFilename=self.progress_plot_name,
+                       useLogscaleList=[True, False])
        np.savez(os.path.join(self.log_dir, "lossValues_{:05d}.npz".format(epoch)),
                 genLossTrain=gen_loss_train,
                 genLossVal=gen_loss_val,
@@ -371,6 +422,7 @@ def loadGenerator(self, trainedModelPath: str):
            raise ValueError(msg)

    def loadIntroModel(self, trainedModelPath: str):
+        '''load pretrained model for 1D-2D conversion as defined by Farre et al.'''
        try:
            introModel = tf.keras.models.load_model(filepath=trainedModelPath)
        except Exception as e:
@@ -385,11 +437,8 @@ def loadIntroModel(self, trainedModelPath: str):
        x = tf.keras.layers.Add()([x, x_T, diag])
        out = tf.keras.layers.Reshape((self.INPUT_SIZE, self.INPUT_SIZE, self.INPUT_CHANNELS))(x)
        introModel_gen = tf.keras.models.Model(inputs=inputs, outputs=out, name="gen_intro_preloaded")
-        introModel_disc = tf.keras.models.Model(inputs=inputs, outputs=out, name="disc_intro_preloaded")
        self.generator_intro_model = introModel_gen
-        self.discriminator_intro_model = introModel_disc
-        self.generator = self.Generator()
-        self.discriminator = self.Discriminator()
+        self.generator = self.Generator()

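Unlike the (x + x^T)/2 averaging used elsewhere in this diff, the pre-loaded Farre et al. intro model symmetrizes by adding the transpose and subtracting the main diagonal once, since x + x^T double-counts it. A standalone check of that identity on a toy matrix:

import tensorflow as tf

x = tf.constant([[1., 2.],
                 [0., 3.]])
diag = -1 * tf.linalg.band_part(x, 0, 0)  # minus the main diagonal of x
sym = x + tf.transpose(x) + diag          # diagonal counted only once
print(sym.numpy())                        # [[1. 2.], [2. 3.]], symmetric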
class CustomReshapeLayer(tf.keras.layers.Layer):
    '''