Commit 50fb99a
Author: Ralf
allow sample flipping / data augmentation
1 parent 0a8615c

2 files changed: +19, -7 lines

records.py (+7, -5)

@@ -29,11 +29,7 @@ def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 #write tfRecord to disk
-def writeTFRecord(pFilename, pRecordDict):
-    if not isinstance(pFilename, str):
-        return
-    if not isinstance(pRecordDict, dict):
-        return
+def writeTFRecord(pFilename: str, pRecordDict: dict):
     for key in pRecordDict:
         if not isinstance(pRecordDict[key], np.ndarray):
             return
@@ -51,3 +47,9 @@ def writeTFRecord(pFilename, pRecordDict):
             feature[key] = _bytes_feature( pRecordDict[key][i].flatten().tostring() )
         example = tf.train.Example(features=tf.train.Features(feature=feature))
         writer.write(example.SerializeToString())
+
+def mirror_function(tensor1, tensor2):
+    t1 = tf.reverse(tensor1, axis=[0])
+    t2 = tf.transpose(tensor2, perm=(1,0,2))
+    t2 = tf.image.rot90(t2, 2)
+    return {"factorData": t1}, {"out_matrixData": t2}

training.py (+12, -2)

@@ -64,6 +64,9 @@
               type=click.FloatRange(min=1e-2, max=1.0),
               default=0.5, show_default=True,
               help="beta1 parameter for Adam optimizer")
+@click.option("--flipsamples", "-fs", required=False,
+              type=bool, default=False, show_default=True,
+              help="Flip training matrices and chromatin features (data augmentation)")
 @click.option("--pretrainedIntroModel", "-ptm", required=False,
               type=click.Path(exists=True, dir_okay=False, readable=True),
               help="pretrained model for 1D-2D conversion of inputs")
@@ -92,6 +95,7 @@ def training(trainmatrices,
              lossweighttv,
              learningrate,
              beta1,
+             flipsamples,
              pretrainedintromodel,
              figuretype,
              recordsize):
@@ -210,8 +214,11 @@ def training(trainmatrices,
                                        num_parallel_reads=tf.data.experimental.AUTOTUNE,
                                        compression_type="GZIP")
     trainDs = trainDs.map(lambda x: records.parse_function(x, storedFeaturesDict), num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    if flipsamples:
+        flippedDs = trainDs.map(lambda a,b: records.mirror_function(a["factorData"], b["out_matrixData"]))
+        trainDs = trainDs.concatenate(flippedDs)
     trainDs = trainDs.shuffle(buffer_size=shuffleBufferSize, reshuffle_each_iteration=True)
-    trainDs = trainDs.batch(batchsize, drop_remainder=False)
+    trainDs = trainDs.batch(batchsize, drop_remainder=True)
     trainDs = trainDs.prefetch(tf.data.experimental.AUTOTUNE)
     #build the input streams for validation
     validationDs = tf.data.TFRecordDataset(valdataRecords,
@@ -221,6 +228,9 @@
     validationDs = validationDs.batch(batchsize)
     validationDs = validationDs.prefetch(tf.data.experimental.AUTOTUNE)
 
+    steps_per_epoch = int( np.floor(nr_trainingSamples / batchsize) )
+    if flipsamples:
+        steps_per_epoch *= 2
     hicGanModel = hicGAN.HiCGAN(log_dir=outfolder,
                                 lambda_pixel=lossweightpixel,
                                 lambda_disc=lossweightdisc,
@@ -232,7 +242,7 @@
     if pretrainedintromodel is not None:
         hicGanModel.loadIntroModel(trainedModelPath=pretrainedintromodel)
     hicGanModel.plotModels(outputpath=outfolder, figuretype=figuretype)
-    hicGanModel.fit(train_ds=trainDs, epochs=epochs, test_ds=validationDs, steps_per_epoch=int( np.floor(nr_trainingSamples / batchsize) ))
+    hicGanModel.fit(train_ds=trainDs, epochs=epochs, test_ds=validationDs, steps_per_epoch=steps_per_epoch)
 
     for tfRecordfile in traindataRecords + valdataRecords:
         if os.path.exists(tfRecordfile):
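The training.py change follows a common tf.data augmentation pattern: map the parsed dataset through the mirror transform, concatenate the mirrored copy onto the original so every sample occurs twice, and shuffle so original and mirrored samples are interleaved. Because batching now uses drop_remainder=True, the number of steps per epoch is computed explicitly and doubled when flipping is enabled, so the fit loop does not run out of data. A self-contained sketch of the same pattern on dummy tensors follows; all names, shapes, and sizes below are illustrative, not taken from the repository.

import numpy as np
import tensorflow as tf

# dummy stand-ins for parsed TFRecord samples: ({"factorData": ...}, {"out_matrixData": ...})
nr_trainingSamples, window, nr_factors, batchsize = 10, 4, 3, 4
feats = np.random.rand(nr_trainingSamples, window, nr_factors).astype("float32")
mats = np.random.rand(nr_trainingSamples, window, window, 1).astype("float32")
trainDs = tf.data.Dataset.from_tensor_slices(({"factorData": feats}, {"out_matrixData": mats}))

def mirror_function(tensor1, tensor2):
    # same transform as records.mirror_function introduced in this commit
    t1 = tf.reverse(tensor1, axis=[0])
    t2 = tf.transpose(tensor2, perm=(1, 0, 2))
    t2 = tf.image.rot90(t2, 2)
    return {"factorData": t1}, {"out_matrixData": t2}

flipsamples = True
if flipsamples:
    flippedDs = trainDs.map(lambda a, b: mirror_function(a["factorData"], b["out_matrixData"]))
    trainDs = trainDs.concatenate(flippedDs)                # original samples + mirrored copies
trainDs = trainDs.shuffle(buffer_size=2 * nr_trainingSamples, reshuffle_each_iteration=True)
trainDs = trainDs.batch(batchsize, drop_remainder=True)     # incomplete final batch is discarded

steps_per_epoch = int(np.floor(nr_trainingSamples / batchsize))
if flipsamples:
    steps_per_epoch *= 2    # dataset size was doubled by the augmentation
print(steps_per_epoch)      # 4, i.e. 2 * floor(10 / 4)

Doubling the floored step count is slightly conservative when the sample count is not a multiple of the batch size (here the 20 augmented samples would allow 5 full batches of 4), which is safe: at most one full batch per epoch goes unused.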
