Skip to content

Commit 3bbe7fd

Browse files
author
Ralf
committed
batchsize, windowsize as parameters
1 parent a10360b commit 3bbe7fd

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

predict.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,29 @@
2727
@click.option("--binsize", "-b", required=True,
2828
type=click.IntRange(min=1000),
2929
help="bin size for binning the chromatin features")
30+
@click.option("--batchsize", "-bs", required=False,
31+
type=click.IntRange(min=1),
32+
default=32, show_default=True,
33+
help="batchsize for predicting")
34+
@click.option("--windowsize", "-ws", required=True,
35+
type=click.Choice(choices=["64", "128", "256"]),
36+
help="windowsize for predicting; must be the same as in trained model. Supported values are 64, 128 and 256")
3037
@click.command()
3138
def prediction(trainedmodel,
3239
testchrompath,
3340
testchroms,
3441
outfolder,
3542
multiplier,
36-
binsize
43+
binsize,
44+
batchsize,
45+
windowsize
3746
):
3847
scalefactors = True
3948
clampfactors = False
4049
scalematrix = True
41-
windowsize = 64
42-
flankingsize = windowsize
4350
maxdist = None
44-
batchSizeInt = 32
51+
windowsize = int(windowsize)
52+
flankingsize = windowsize
4553

4654
paramDict = locals().copy()
4755

@@ -79,7 +87,7 @@ def prediction(trainedmodel,
7987
raise SystemExit(msg)
8088
tfRecordFilenames.append(container.writeTFRecord(pOutfolder=outfolder,
8189
pRecordSize=None)[0]) #list with 1 entry
82-
sampleSizeList.append( int( np.ceil(container.getNumberSamples() / batchSizeInt) ) )
90+
sampleSizeList.append( int( np.ceil(container.getNumberSamples() / batchsize) ) )
8391

8492
nr_factors = container0.nr_factors
8593
#data is no longer needed, unload it
@@ -95,7 +103,7 @@ def prediction(trainedmodel,
95103
num_parallel_reads=None,
96104
compression_type="GZIP")
97105
testDs = testDs.map(lambda x: records.parse_function(x, storedFeaturesDict), num_parallel_calls=tf.data.experimental.AUTOTUNE)
98-
testDs = testDs.batch(batchSizeInt, drop_remainder=False) #do NOT drop the last batch (maybe incomplete, i.e. smaller, because batch size doesn't integer divide chrom size)
106+
testDs = testDs.batch(batchsize, drop_remainder=False) #do NOT drop the last batch (maybe incomplete, i.e. smaller, because batch size doesn't integer divide chrom size)
99107
#if validationmatrix is not None:
100108
# testDs = testDs.map(lambda x, y: x) #drop the target matrices (they are for evaluation)
101109
testDs = testDs.prefetch(tf.data.experimental.AUTOTUNE)

0 commit comments

Comments
 (0)