Commit 7574a60

Author: Ralf
Merge: 8a21ee2 5a1ef31

    Merge commit '5a1ef31f11ea83d7e8b7b065b9f19a03439349a4' into currentOpt

5 files changed: +191 −70 lines

hicGAN.py (+90 −41)
@@ -12,10 +12,14 @@
 
 class HiCGAN():
     def __init__(self, log_dir: str,
-                 lambda_pixel: float = 1e-5,
-                 loss_type_pixel: str = "L2",
+                 lambda_pixel: float = 100,
+                 lambda_disc: float = 0.5,
+                 loss_type_pixel: str = "L1",
                  tv_weight: float = 1e-10,
-                 input_size: int = 256):
+                 input_size: int = 256,
+                 plot_frequency: int = 20,
+                 learning_rate: float = 2e-5,
+                 adam_beta_1: float = 0.5):
         super().__init__()
 
         self.OUTPUT_CHANNELS = 1
@@ -24,15 +28,15 @@ def __init__(self, log_dir: str,
         if input_size in [64,128,256]:
             self.INPUT_SIZE = input_size
         self.NR_FACTORS = 14
-        self.LAMBDA = lambda_pixel
+        self.lambda_pixel = lambda_pixel
+        self.lambda_disc = lambda_disc
         self.tv_loss_Weight = tv_weight
         self.loss_type_pixel = loss_type_pixel
         self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
-        self.generator_optimizer = tf.keras.optimizers.Adam(2e-5, beta_1=0.5)
-        self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-5, beta_1=0.5)
+        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=adam_beta_1, name="Adam_Generator")
+        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=adam_beta_1, name="Adam_Discriminator")
 
         self.generator_intro_model = self.oneD_twoD_conversion()
-        self.discriminator_intro_model = self.oneD_twoD_conversion()
         self.generator = self.Generator()
         self.discriminator = self.Discriminator()
 
@@ -46,8 +50,10 @@ def __init__(self, log_dir: str,
                                               discriminator=self.discriminator)
 
         self.progress_plot_name = os.path.join(self.log_dir, "lossOverEpochs.png")
+        self.progress_plot_frequency = plot_frequency
+        self.example_plot_frequency = 5
 
-    def oneD_twoD_conversion(self, nr_filters_list=[16,16,32,32,64], kernel_width_list=[4,4,4,4,4], nr_neurons_List=[5000,4000,3000]):
+    def oneD_twoD_conversion(self, nr_filters_list=[1024,512,512,256,256,128,128,64], kernel_width_list=[4,4,4,4,4,4,4,4], apply_dropout: bool = False):
         inputs = tf.keras.layers.Input(shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
         #add 1D convolutions
         x = inputs
@@ -57,17 +63,26 @@ def oneD_twoD_conversion(self, nr_filters_list=[16,16,32,32,64], kernel_width_li
             convParamDict["filters"] = nr_filters
             convParamDict["kernel_size"] = kernelWidth
             convParamDict["data_format"]="channels_last"
+            convParamDict["kernel_regularizer"]=tf.keras.regularizers.l2(0.01)
             if kernelWidth > 1:
                 convParamDict["padding"] = "same"
             x = Conv1D(**convParamDict)(x)
             x = BatchNormalization()(x)
-            x = tf.keras.layers.Activation("sigmoid")(x)
-        #make the shape of a 2D-image
-        x = Conv1D(filters=self.INPUT_SIZE, strides=3, kernel_size=4, data_format="channels_last", activation="sigmoid", padding="same", name="conv1D_final")(x)
-        y = tf.keras.layers.Permute((2,1))(x)
-        diag = tf.keras.layers.Lambda(lambda z: -1*tf.linalg.band_part(z, 0, 0))(x)
-        x = tf.keras.layers.Add()([x, y, diag])
-
+            if apply_dropout:
+                x = Dropout(0.5)(x)
+            x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
+        #make the shape of a square matrix
+        x = Conv1D(filters=self.INPUT_SIZE,
+                   strides=3,
+                   kernel_size=4,
+                   data_format="channels_last",
+                   activation="sigmoid",
+                   padding="same", name="conv1D_final")(x)
+        #ensure the matrix is symmetric, i.e. x = transpose(x)
+        x_T = tf.keras.layers.Permute((2,1))(x) #this is the matrix transpose
+        x = tf.keras.layers.Add()([x, x_T])
+        x = tf.keras.layers.Lambda(lambda z: 0.5*z)(x) #add transpose and divide by 2
+        #reshape the matrix into a 2D grayscale image
         x = tf.keras.layers.Reshape((self.INPUT_SIZE,self.INPUT_SIZE,self.INPUT_CHANNELS))(x)
         model = tf.keras.Model(inputs=inputs, outputs=x, name="crazy_intro_model")
         #model.build(input_shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
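Note: the Permute/Add/Lambda triple added above computes (x + xᵀ)/2, the closest symmetric matrix to x, which fits Hi-C contact matrices since they are symmetric by construction; it replaces the earlier band_part-based diagonal correction. A minimal standalone sketch of the same operation (the symmetrize name is illustrative, not part of this codebase):

    import tensorflow as tf

    def symmetrize(x):
        # x has shape (batch, N, N); Permute swaps the two matrix axes
        x_T = tf.keras.layers.Permute((2, 1))(x)             # the transpose
        x = tf.keras.layers.Add()([x, x_T])                  # x + x^T is symmetric
        return tf.keras.layers.Lambda(lambda z: 0.5 * z)(x)  # halve to keep the value range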
@@ -141,8 +156,7 @@ def Generator(self):
         last = tf.keras.layers.Conv2DTranspose(self.OUTPUT_CHANNELS, 4,
                                                strides=2,
                                                padding='same',
-                                               kernel_initializer=initializer,
-                                               activation='sigmoid') # (bs, 256, 256, 3)
+                                               kernel_initializer=initializer) # (bs, 256, 256, 3)
 
         x = inputs
         x = twoD_conversion(x)
@@ -161,8 +175,11 @@ def Generator(self):
             x = tf.keras.layers.Concatenate()([x, skip])
 
         x = last(x)
+        #enforce symmetry
         x_T = tf.keras.layers.Permute((2,1,3))(x)
         x = tf.keras.layers.Add()([x, x_T])
+        x = tf.keras.layers.Lambda(lambda z: 0.5*z)(x)
+        x = tf.keras.layers.Activation("sigmoid")(x)
 
         return tf.keras.Model(inputs=inputs, outputs=x)
 
@@ -175,7 +192,7 @@ def generator_loss(self, disc_generated_output, gen_output, target):
         else:
             pixel_loss = tf.reduce_mean(tf.square(target - gen_output))
         tv_loss = tf.reduce_mean(tf.image.total_variation(gen_output))
-        total_gen_loss = pixel_loss + 1/self.LAMBDA * gan_loss + self.tv_loss_Weight * tv_loss
+        total_gen_loss = self.lambda_pixel * pixel_loss + self.lambda_disc * gan_loss + self.tv_loss_Weight * tv_loss
         return total_gen_loss, gan_loss, pixel_loss
 
 
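Note on the reweighting: with the old default lambda_pixel = 1e-5, the term 1/self.LAMBDA weighted the adversarial loss 100000:1 over the pixel loss; the new defaults (lambda_pixel = 100, lambda_disc = 0.5) flip that to 200:1 in favour of the pixel term, in line with the pix2pix convention of a strong reconstruction loss. A sketch of the new combination with the commit's defaults spelled out:

    def combined_generator_loss(pixel_loss, gan_loss, tv_loss,
                                lambda_pixel=100.0, lambda_disc=0.5, tv_weight=1e-10):
        # weights taken from the new __init__ defaults in this commit
        return lambda_pixel * pixel_loss + lambda_disc * gan_loss + tv_weight * tv_loss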
@@ -184,7 +201,7 @@ def Discriminator(self):
 
         inp = tf.keras.layers.Input(shape=[3*self.INPUT_SIZE, self.NR_FACTORS], name='input_image')
         tar = tf.keras.layers.Input(shape=[self.INPUT_SIZE, self.INPUT_SIZE, self.OUTPUT_CHANNELS], name='target_image')
-        twoD_conversion = self.discriminator_intro_model
+        twoD_conversion = self.oneD_twoD_conversion()
         #x = Flatten()(inp)
         #x = Dense(units = self.INPUT_SIZE*(self.INPUT_SIZE+1)//2)(x)
         #x = tf.keras.layers.LeakyReLU()(x)
@@ -198,24 +215,46 @@ def Discriminator(self):
         d = twoD_conversion(inp)
         d = tf.keras.layers.Concatenate()([d, tar])
         if self.INPUT_SIZE > 64:
+            #downsample and symmetrize 1
             d = HiCGAN.downsample(64, 4, False)(d) # (bs, inp.size/2, inp.size/2, 64)
+            d_T = tf.keras.layers.Permute((2,1,3))(d)
+            d = tf.keras.layers.Add()([d, d_T])
+            d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
+            #downsample and symmetrize 2
             d = HiCGAN.downsample(128, 4)(d)# (bs, inp.size/4, inp.size/4, 128)
+            d_T = tf.keras.layers.Permute((2,1,3))(d)
+            d = tf.keras.layers.Add()([d, d_T])
+            d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
         else:
+            #downsample and symmetrize 3
             d = HiCGAN.downsample(256, 4)(d)
+            d_T = tf.keras.layers.Permute((2,1,3))(d)
+            d = tf.keras.layers.Add()([d, d_T])
+            d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
+        #downsample and symmetrize 4
         d = HiCGAN.downsample(256, 4)(d) # (bs, inp.size/8, inp.size/8, 256)
+        d_T = tf.keras.layers.Permute((2,1,3))(d)
+        d = tf.keras.layers.Add()([d, d_T])
+        d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
         d = Conv2D(512, 4, strides=1, padding="same", kernel_initializer=initializer)(d) #(bs, inp.size/8, inp.size/8, 512)
+        d_T = tf.keras.layers.Permute((2,1,3))(d)
+        d = tf.keras.layers.Add()([d, d_T])
+        d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
         d = BatchNormalization()(d)
         d = LeakyReLU(alpha=0.2)(d)
         d = Conv2D(1, 4, strides=1, padding="same",
                    kernel_initializer=initializer)(d) #(bs, inp.size/8, inp.size/8, 1)
+        d_T = tf.keras.layers.Permute((2,1,3))(d)
+        d = tf.keras.layers.Add()([d, d_T])
+        d = tf.keras.layers.Lambda(lambda z: 0.5*z)(d)
         d = tf.keras.layers.Activation("sigmoid")(d)
         return tf.keras.Model(inputs=[inp, tar], outputs=d)
 
     def discriminator_loss(self, disc_real_output, disc_generated_output):
         real_loss = self.loss_object(tf.ones_like(disc_real_output), disc_real_output)
         generated_loss = self.loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
         total_disc_loss = real_loss + generated_loss
-        return total_disc_loss
+        return total_disc_loss, real_loss, generated_loss
 
 
     @tf.function
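The same Permute/Add/Lambda symmetrization now follows every discriminator stage, keeping the patch output symmetric like its Hi-C input. A hypothetical helper that would factor out the repetition (downsample_sym is illustrative, not in the codebase; HiCGAN.downsample takes filters, a kernel size, and an optional batch-norm flag, as the calls above show):

    def downsample_sym(d, nr_filters, kernel_size, apply_batchnorm=True):
        # one downsample stage followed by the (d + d^T)/2 symmetrization
        d = HiCGAN.downsample(nr_filters, kernel_size, apply_batchnorm)(d)
        d_T = tf.keras.layers.Permute((2, 1, 3))(d)   # transpose the spatial axes
        d = tf.keras.layers.Add()([d, d_T])
        return tf.keras.layers.Lambda(lambda z: 0.5 * z)(d)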
@@ -226,8 +265,8 @@ def train_step(self, input_image, target, epoch):
             disc_real_output = self.discriminator([input_image, target], training=True)
             disc_generated_output = self.discriminator([input_image, gen_output], training=True)
 
-            gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(disc_generated_output, gen_output, target)
-            disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
+            gen_total_loss, _, _ = self.generator_loss(disc_generated_output, gen_output, target)
+            disc_loss, disc_real_loss, disc_gen_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
 
         generator_gradients = gen_tape.gradient(gen_total_loss,
                                                 self.generator.trainable_variables)
@@ -239,7 +278,7 @@ def train_step(self, input_image, target, epoch):
         self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                          self.discriminator.trainable_variables))
 
-        return gen_total_loss, disc_loss
+        return gen_total_loss, disc_loss, disc_real_loss, disc_gen_loss
 
     @tf.function
     def validationStep(self, input_image, target, epoch):
@@ -248,8 +287,8 @@ def validationStep(self, input_image, target, epoch):
         disc_real_output = self.discriminator([input_image, target], training=True)
         disc_generated_output = self.discriminator([input_image, gen_output], training=True)
 
-        gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(disc_generated_output, gen_output, target)
-        disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
+        gen_total_loss, _, _ = self.generator_loss(disc_generated_output, gen_output, target)
+        disc_loss, _, _ = self.discriminator_loss(disc_real_output, disc_generated_output)
 
         return gen_total_loss, disc_loss
 
@@ -273,29 +312,37 @@ def fit(self, train_ds, epochs, test_ds, steps_per_epoch: int):
         gen_loss_train = []
         gen_loss_val = []
         disc_loss_train =[]
+        disc_loss_real_train = []
+        disc_loss_gen_train = []
         disc_loss_val = []
         for epoch in range(epochs):
             #generate sample output
-            if epoch % 5 == 0:
+            if epoch % self.example_plot_frequency == 0:
                 for example_input, example_target in test_ds.take(1):
                     self.generate_images(self.generator, example_input, example_target, epoch)
             # Train
             train_pbar = tqdm(train_ds.enumerate(), total=steps_per_epoch)
             train_pbar.set_description("Epoch {:05d}".format(epoch+1))
             gen_loss_batches = []
             disc_loss_batches = []
+            disc_real_loss_batches = []
+            disc_gen_loss_batches = []
             for _, (input_image, target) in train_pbar:
-                gen_loss, disc_loss = self.train_step(input_image["factorData"], target["out_matrixData"], epoch)
+                gen_loss, disc_loss, disc_real_loss, disc_gen_loss = self.train_step(input_image["factorData"], target["out_matrixData"], epoch)
                 gen_loss_batches.append(gen_loss)
                 disc_loss_batches.append(disc_loss)
+                disc_real_loss_batches.append(disc_real_loss)
+                disc_gen_loss_batches.append(disc_gen_loss)
                 if epoch == 0:
                     train_pbar.set_postfix( {"loss": "{:.4f}".format(gen_loss)} )
                 else:
                     train_pbar.set_postfix( {"train loss": "{:.4f}".format(gen_loss),
                                              "val loss": "{:.4f}".format(gen_loss_val[-1])} )
             gen_loss_train.append(np.mean(gen_loss_batches))
             disc_loss_train.append(np.mean(disc_loss_batches))
-            del gen_loss_batches, disc_loss_batches
+            disc_loss_real_train.append(np.mean(disc_real_loss_batches))
+            disc_loss_gen_train.append(np.mean(disc_gen_loss_batches))
+            del gen_loss_batches, disc_loss_batches, disc_real_loss_batches, disc_gen_loss_batches
             # Validation
             gen_loss_batches = []
             disc_loss_batches = []
@@ -308,13 +355,15 @@ def fit(self, train_ds, epochs, test_ds, steps_per_epoch: int):
             del gen_loss_batches, disc_loss_batches, train_pbar
 
             # saving (checkpoint) the model every 20 epochs
-            if (epoch + 1) % 20 == 0:
+            if (epoch + 1) % self.progress_plot_frequency == 0:
                 #self.checkpoint.save(file_prefix = self.checkpoint_prefix)
                 #plot loss
-                utils.plotLoss(pLossValueLists=[gen_loss_train, gen_loss_val, disc_loss_train, disc_loss_val],
-                               pNameList=["gen.loss train", "gen.loss val", "disc.loss train", "disc.loss val"],
-                               pFilename=self.progress_plot_name,
-                               useLogscale=True)
+                utils.plotLoss(pGeneratorLossValueLists=[gen_loss_train, gen_loss_val],
+                               pDiscLossValueLists=[disc_loss_train, disc_loss_real_train, disc_loss_gen_train, disc_loss_val],
+                               pGeneratorLossNameList=["training", "validation"],
+                               pDiscLossNameList=["train total", "train real", "train gen.", "valid. total"],
+                               pFilename=self.progress_plot_name,
+                               useLogscaleList=[True, False])
                 np.savez(os.path.join(self.log_dir, "lossValues_{:05d}.npz".format(epoch)),
                          genLossTrain=gen_loss_train,
                          genLossVal=gen_loss_val,
@@ -325,10 +374,12 @@ def fit(self, train_ds, epochs, test_ds, steps_per_epoch: int):
 
 
         self.checkpoint.save(file_prefix = self.checkpoint_prefix)
-        utils.plotLoss(pLossValueLists=[gen_loss_train, gen_loss_val, disc_loss_train, disc_loss_val],
-                       pNameList=["gen.loss train", "gen.loss val", "disc.loss train", "disc.loss val"],
-                       pFilename=self.progress_plot_name,
-                       useLogscale=True)
+        utils.plotLoss(pGeneratorLossValueLists=[gen_loss_train, gen_loss_val],
+                       pDiscLossValueLists=[disc_loss_train, disc_loss_real_train, disc_loss_gen_train, disc_loss_val],
+                       pGeneratorLossNameList=["training", "validation"],
+                       pDiscLossNameList=["train total", "train real", "train gen.", "valid. total"],
+                       pFilename=self.progress_plot_name,
+                       useLogscaleList=[True, False])
         np.savez(os.path.join(self.log_dir, "lossValues_{:05d}.npz".format(epoch)),
                  genLossTrain=gen_loss_train,
                  genLossVal=gen_loss_val,
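utils.plotLoss now takes separate value/name lists for generator and discriminator plus per-plot log-scale flags. A minimal sketch of a compatible implementation, assuming matplotlib (the actual utils.py in this repository may differ):

    import matplotlib.pyplot as plt

    def plotLoss(pGeneratorLossValueLists, pDiscLossValueLists,
                 pGeneratorLossNameList, pDiscLossNameList,
                 pFilename, useLogscaleList):
        # one subplot per model; log scale per the flags in useLogscaleList
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        for values, name in zip(pGeneratorLossValueLists, pGeneratorLossNameList):
            axs[0].plot(values, label=name)
        for values, name in zip(pDiscLossValueLists, pDiscLossNameList):
            axs[1].plot(values, label=name)
        for ax, title, log in zip(axs, ("generator loss", "discriminator loss"),
                                  useLogscaleList):
            if log:
                ax.set_yscale("log")
            ax.set_xlabel("epoch")
            ax.set_title(title)
            ax.legend()
        fig.savefig(pFilename)
        plt.close(fig)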
@@ -371,6 +422,7 @@ def loadGenerator(self, trainedModelPath: str):
             raise ValueError(msg)
 
     def loadIntroModel(self, trainedModelPath: str):
+        '''load pretrained model for 1D-2D conversion as defined by Farre et al.'''
         try:
             introModel = tf.keras.models.load_model(filepath=trainedModelPath)
         except Exception as e:
@@ -385,11 +437,8 @@ def loadIntroModel(self, trainedModelPath: str):
         x = tf.keras.layers.Add()([x, x_T, diag])
         out = tf.keras.layers.Reshape((self.INPUT_SIZE, self.INPUT_SIZE, self.INPUT_CHANNELS))(x)
         introModel_gen = tf.keras.models.Model(inputs=inputs, outputs=out, name="gen_intro_preloaded")
-        introModel_disc = tf.keras.models.Model(inputs=inputs, outputs=out, name="disc_intro_preloaded")
         self.generator_intro_model = introModel_gen
-        self.discriminator_intro_model = introModel_disc
-        self.generator = self.Generator()
-        self.discriminator = self.Discriminator()
+        self.generator = self.Generator()
 
 class CustomReshapeLayer(tf.keras.layers.Layer):
     '''

predict.py (+17 −3)
@@ -1,6 +1,7 @@
 import click
 import numpy as np
 import os
+import csv
 import tensorflow as tf
 import dataContainer
 import records
@@ -20,11 +21,15 @@
                 type=click.Path(exists=True, writable=True),
                 default="./", show_default=True,
                 help="Output path for predicted coolers")
+@click.option("--multiplier", "-mul", required=False,
+                type=click.IntRange(min=1),
+                default=10, show_default=True)
 @click.command()
 def prediction(trainedmodel,
                testchrompath,
                testchroms,
-               outfolder
+               outfolder,
+               multiplier
                ):
     binSizeInt = 25000
     scalefactors = True
@@ -34,9 +39,9 @@ def prediction(trainedmodel,
     flankingsize = windowsize
     maxdist = None
     batchSizeInt = 32
-    multiplier = 10
 
-
+    paramDict = locals().copy()
+
     #extract chromosome names from the input
     chromNameList = testchroms.replace(",", " ").rstrip().split(" ")
     chromNameList = sorted([x.lstrip("chr") for x in chromNameList])
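paramDict = locals().copy() snapshots every local name defined up to that point, i.e. the click arguments plus the constants above, so they can be written to predParams.csv further down; locals created after the copy stay out of the dump. A self-contained illustration (names are illustrative):

    def run(alpha, beta):
        gamma = 42                  # defined before the snapshot: included
        params = locals().copy()
        result = [alpha] * gamma    # defined after the snapshot: excluded
        return params

    print(run(1.0, "x"))            # {'alpha': 1.0, 'beta': 'x', 'gamma': 42}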
@@ -98,6 +103,15 @@ def prediction(trainedmodel,
                               pOutfile=matrixname,
                               pChromosomeList=chromNameList)
 
+    parameterFile = os.path.join(outfolder, "predParams.csv")
+    with open(parameterFile, "w") as csvfile:
+        dictWriter = csv.DictWriter(csvfile, fieldnames=sorted(list(paramDict.keys())))
+        dictWriter.writeheader()
+        dictWriter.writerow(paramDict)
+
+    for tfrecordfile in tfRecordFilenames:
+        if os.path.exists(tfrecordfile):
+            os.remove(tfrecordfile)
 
 if __name__ == "__main__":
     prediction() #pylint: disable=no-value-for-parameter

records.py (+7 −5)
@@ -29,11 +29,7 @@ def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 #write tfRecord to disk
-def writeTFRecord(pFilename, pRecordDict):
-    if not isinstance(pFilename, str):
-        return
-    if not isinstance(pRecordDict, dict):
-        return
+def writeTFRecord(pFilename: str, pRecordDict: dict):
     for key in pRecordDict:
         if not isinstance(pRecordDict[key], np.ndarray):
             return
@@ -51,3 +47,9 @@ def writeTFRecord(pFilename, pRecordDict):
             feature[key] = _bytes_feature( pRecordDict[key][i].flatten().tostring() )
         example = tf.train.Example(features=tf.train.Features(feature=feature))
         writer.write(example.SerializeToString())
+
+def mirror_function(tensor1, tensor2):
+    t1 = tf.reverse(tensor1, axis=[0])
+    t2 = tf.transpose(tensor2, perm=(1,0,2))
+    t2 = tf.image.rot90(t2, 2)
+    return {"factorData": t1}, {"out_matrixData": t2}
