64
64
type = click .FloatRange (min = 1e-2 , max = 1.0 ),
65
65
default = 0.5 , show_default = True ,
66
66
help = "beta1 parameter for Adam optimizer" )
67
+ @click .option ("--flipsamples" , "-fs" , required = False ,
68
+ type = bool , default = False , show_default = True ,
69
+ help = "Flip training matrices and chromatin features (data augmentation)" )
67
70
@click .option ("--pretrainedIntroModel" , "-ptm" , required = False ,
68
71
type = click .Path (exists = True , dir_okay = False , readable = True ),
69
72
help = "pretrained model for 1D-2D conversion of inputs" )
@@ -92,6 +95,7 @@ def training(trainmatrices,
92
95
lossweighttv ,
93
96
learningrate ,
94
97
beta1 ,
98
+ flipsamples ,
95
99
pretrainedintromodel ,
96
100
figuretype ,
97
101
recordsize ):
@@ -210,8 +214,11 @@ def training(trainmatrices,
210
214
num_parallel_reads = tf .data .experimental .AUTOTUNE ,
211
215
compression_type = "GZIP" )
212
216
trainDs = trainDs .map (lambda x : records .parse_function (x , storedFeaturesDict ), num_parallel_calls = tf .data .experimental .AUTOTUNE )
217
+ if flipsamples :
218
+ flippedDs = trainDs .map (lambda a ,b : records .mirror_function (a ["factorData" ], b ["out_matrixData" ]))
219
+ trainDs = trainDs .concatenate (flippedDs )
213
220
trainDs = trainDs .shuffle (buffer_size = shuffleBufferSize , reshuffle_each_iteration = True )
214
- trainDs = trainDs .batch (batchsize , drop_remainder = False )
221
+ trainDs = trainDs .batch (batchsize , drop_remainder = True )
215
222
trainDs = trainDs .prefetch (tf .data .experimental .AUTOTUNE )
216
223
#build the input streams for validation
217
224
validationDs = tf .data .TFRecordDataset (valdataRecords ,
@@ -221,6 +228,9 @@ def training(trainmatrices,
221
228
validationDs = validationDs .batch (batchsize )
222
229
validationDs = validationDs .prefetch (tf .data .experimental .AUTOTUNE )
223
230
231
+ steps_per_epoch = int ( np .floor (nr_trainingSamples / batchsize ) )
232
+ if flipsamples :
233
+ steps_per_epoch *= 2
224
234
hicGanModel = hicGAN .HiCGAN (log_dir = outfolder ,
225
235
lambda_pixel = lossweightpixel ,
226
236
lambda_disc = lossweightdisc ,
@@ -232,7 +242,7 @@ def training(trainmatrices,
232
242
if pretrainedintromodel is not None :
233
243
hicGanModel .loadIntroModel (trainedModelPath = pretrainedintromodel )
234
244
hicGanModel .plotModels (outputpath = outfolder , figuretype = figuretype )
235
- hicGanModel .fit (train_ds = trainDs , epochs = epochs , test_ds = validationDs , steps_per_epoch = int ( np . floor ( nr_trainingSamples / batchsize ) ) )
245
+ hicGanModel .fit (train_ds = trainDs , epochs = epochs , test_ds = validationDs , steps_per_epoch = steps_per_epoch )
236
246
237
247
for tfRecordfile in traindataRecords + valdataRecords :
238
248
if os .path .exists (tfRecordfile ):
0 commit comments