Skip to content

Commit 52ac4a4

Browse files
author
Ralf
committed
make number of features a parameter
1 parent e86a3b5 commit 52ac4a4

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

hicGAN.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
class HiCGAN():
1515
def __init__(self, log_dir: str,
16+
number_factors: int,
1617
loss_weight_pixel: float = 100, #factor for L1/L2 loss, like Isola et al. 2017
1718
loss_weight_adversarial: float = 1.0, #factor for adversarial loss in generator
1819
loss_weight_discriminator: float = 0.5, #factor for disc loss, like Isola et al. 2017
@@ -30,10 +31,10 @@ def __init__(self, log_dir: str,
3031

3132
self.OUTPUT_CHANNELS = 1
3233
self.INPUT_CHANNELS = 1
33-
self.INPUT_SIZE = 256
34+
self.input_size = 256
3435
if input_size in [64,128,256]:
35-
self.INPUT_SIZE = input_size
36-
self.NR_FACTORS = 14
36+
self.input_size = input_size
37+
self.number_factors = number_factors
3738
self.loss_weight_pixel = loss_weight_pixel
3839
self.loss_weight_discriminator = loss_weight_discriminator
3940
self.loss_weight_adversarial = loss_weight_adversarial
@@ -90,7 +91,7 @@ def __init__(self, log_dir: str,
9091
self.__batch_counter = 0
9192

9293
def cnn_embedding(self, nr_filters_list=[1024,512,512,256,256,128,128,64], kernel_width_list=[4,4,4,4,4,4,4,4], apply_dropout: bool = False):
93-
inputs = tf.keras.layers.Input(shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
94+
inputs = tf.keras.layers.Input(shape=(3*self.input_size, self.number_factors))
9495
#add 1D convolutions
9596
x = inputs
9697
for i, (nr_filters, kernelWidth) in enumerate(zip(nr_filters_list, kernel_width_list)):
@@ -108,7 +109,7 @@ def cnn_embedding(self, nr_filters_list=[1024,512,512,256,256,128,128,64], kerne
108109
x = Dropout(0.5)(x)
109110
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
110111
#make the shape of a square matrix
111-
x = Conv1D(filters=self.INPUT_SIZE,
112+
x = Conv1D(filters=self.input_size,
112113
strides=3,
113114
kernel_size=4,
114115
data_format="channels_last",
@@ -119,14 +120,14 @@ def cnn_embedding(self, nr_filters_list=[1024,512,512,256,256,128,128,64], kerne
119120
x = tf.keras.layers.Add()([x, x_T])
120121
x = tf.keras.layers.Lambda(lambda z: 0.5*z)(x) #add transpose and divide by 2
121122
#reshape the matrix into a 2D grayscale image
122-
x = tf.keras.layers.Reshape((self.INPUT_SIZE,self.INPUT_SIZE,self.INPUT_CHANNELS))(x)
123+
x = tf.keras.layers.Reshape((self.input_size,self.input_size,self.INPUT_CHANNELS))(x)
123124
model = tf.keras.Model(inputs=inputs, outputs=x, name="CNN-embedding")
124125
#model.build(input_shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
125126
#model.summary()
126127
return model
127128

128129
def dnn_embedding(self, pretrained_model_path : str = ""):
129-
inputs = tf.keras.layers.Input(shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
130+
inputs = tf.keras.layers.Input(shape=(3*self.input_size, self.number_factors))
130131
x = Conv1D(filters=1,
131132
kernel_size=1,
132133
strides=1,
@@ -139,7 +140,7 @@ def dnn_embedding(self, pretrained_model_path : str = ""):
139140
x = Dense(nr_neurons, activation="relu", kernel_regularizer="l2", name=layerName)(x)
140141
layerName = "dropout_" + str(i+1)
141142
x = Dropout(0.1, name=layerName)(x)
142-
nr_output_neurons = (self.INPUT_SIZE * (self.INPUT_SIZE + 1)) // 2
143+
nr_output_neurons = (self.input_size * (self.input_size + 1)) // 2
143144
x = Dense(nr_output_neurons, activation="relu",kernel_regularizer="l2", name="dense_out")(x)
144145
dnn_model = tf.keras.Model(inputs=inputs, outputs=x)
145146
if pretrained_model_path != "":
@@ -150,15 +151,15 @@ def dnn_embedding(self, pretrained_model_path : str = ""):
150151
msg = str(e)
151152
msg += "\nCould not load the weights of pre-trained model"
152153
print(msg)
153-
inputs2 = tf.keras.layers.Input(shape=(3*self.INPUT_SIZE, self.NR_FACTORS))
154+
inputs2 = tf.keras.layers.Input(shape=(3*self.input_size, self.number_factors))
154155
x = dnn_model(inputs2)
155156
#place the upper triangular part from dnn model into full matrix
156-
x = CustomReshapeLayer(self.INPUT_SIZE)(x)
157+
x = CustomReshapeLayer(self.input_size)(x)
157158
#symmetrize the output
158159
x_T = tf.keras.layers.Permute((2,1))(x)
159160
diag = tf.keras.layers.Lambda(lambda z: -1*tf.linalg.band_part(z, 0, 0))(x)
160161
x = tf.keras.layers.Add()([x, x_T, diag])
161-
out = tf.keras.layers.Reshape((self.INPUT_SIZE, self.INPUT_SIZE, self.INPUT_CHANNELS))(x)
162+
out = tf.keras.layers.Reshape((self.input_size, self.input_size, self.INPUT_CHANNELS))(x)
162163
dnn_embedding = tf.keras.Model(inputs=inputs2, outputs=out, name="DNN-embedding")
163164
return dnn_embedding
164165

@@ -189,7 +190,7 @@ def upsample(filters, size, apply_dropout=False):
189190

190191

191192
def Generator(self):
192-
inputs = tf.keras.layers.Input(shape=[3*self.INPUT_SIZE,self.NR_FACTORS], name="factorData")
193+
inputs = tf.keras.layers.Input(shape=[3*self.input_size,self.number_factors], name="factorData")
193194

194195
twoD_conversion = self.generator_embedding
195196
#the downsampling part of the network, defined for 256x256 images
@@ -204,9 +205,9 @@ def Generator(self):
204205
HiCGAN.downsample(512, 4, apply_batchnorm=False), # (bs, 1, 1, 512)
205206
]
206207
#if the input images are smaller, leave out some layers accordingly
207-
if self.INPUT_SIZE < 256:
208+
if self.input_size < 256:
208209
down_stack = down_stack[:-2] + down_stack[-1:]
209-
if self.INPUT_SIZE < 128:
210+
if self.input_size < 128:
210211
down_stack = down_stack[:-2] + down_stack[-1:]
211212

212213
#the upsampling portion of the generator, designed for 256x256 images
@@ -220,9 +221,9 @@ def Generator(self):
220221
HiCGAN.upsample(64, 4), # (bs, 128, 128, 128)
221222
]
222223
#for smaller images, take layers away, otherwise downsampling won't work
223-
if self.INPUT_SIZE < 256:
224+
if self.input_size < 256:
224225
up_stack = up_stack[:2] + up_stack[3:]
225-
if self.INPUT_SIZE < 128:
226+
if self.input_size < 128:
226227
up_stack = up_stack[:2] + up_stack[3:]
227228

228229
initializer = tf.random_normal_initializer(0., 0.02)
@@ -272,13 +273,13 @@ def generator_loss(self, disc_generated_output, gen_output, target):
272273
def Discriminator(self):
273274
initializer = tf.random_normal_initializer(0., 0.02)
274275

275-
inp = tf.keras.layers.Input(shape=[3*self.INPUT_SIZE, self.NR_FACTORS], name='input_image')
276-
tar = tf.keras.layers.Input(shape=[self.INPUT_SIZE, self.INPUT_SIZE, self.OUTPUT_CHANNELS], name='target_image')
276+
inp = tf.keras.layers.Input(shape=[3*self.input_size, self.number_factors], name='input_image')
277+
tar = tf.keras.layers.Input(shape=[self.input_size, self.input_size, self.OUTPUT_CHANNELS], name='target_image')
277278
embedding = self.discriminator_embedding
278279
#Patch-GAN (Isola et al.)
279280
d = embedding(inp)
280281
d = tf.keras.layers.Concatenate()([d, tar])
281-
if self.INPUT_SIZE > 64:
282+
if self.input_size > 64:
282283
#downsample and symmetrize 1
283284
d = HiCGAN.downsample(64, 4, False)(d) # (bs, inp.size/2, inp.size/2, 64)
284285
d_T = tf.keras.layers.Permute((2,1,3))(d)
@@ -485,7 +486,7 @@ def loadGenerator(self, trainedModelPath: str):
485486
'''
486487
try:
487488
trainedModel = tf.keras.models.load_model(filepath=trainedModelPath,
488-
custom_objects={"CustomReshapeLayer": CustomReshapeLayer(self.INPUT_SIZE)})
489+
custom_objects={"CustomReshapeLayer": CustomReshapeLayer(self.input_size)})
489490
self.generator = trainedModel
490491
except Exception as e:
491492
msg = str(e)

predict.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def prediction(trainedmodel,
6666
print(msg)
6767
return #nothing to do
6868
container0 = testdataContainerList[0]
69+
nr_factors = container0.nr_factors
6970
tfRecordFilenames = []
7071
sampleSizeList = []
7172
for container in testdataContainerList:
@@ -77,7 +78,7 @@ def prediction(trainedmodel,
7778
sampleSizeList.append( int( np.ceil(container.getNumberSamples() / batchSizeInt) ) )
7879
container.unloadData()
7980

80-
trained_GAN = hicGAN.HiCGAN(log_dir=outfolder)
81+
trained_GAN = hicGAN.HiCGAN(log_dir=outfolder, number_factors=nr_factors)
8182
trained_GAN.loadGenerator(trainedModelPath=trainedmodel)
8283
predList = []
8384
for record, container, nr_samples in zip(tfRecordFilenames, testdataContainerList, sampleSizeList):

training.py

+1
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def training(trainmatrices,
257257
if pretrainedintromodel is None:
258258
pretrainedintromodel = ""
259259
hicGanModel = hicGAN.HiCGAN(log_dir=outfolder,
260+
number_factors=nr_factors,
260261
loss_weight_pixel=lossweightpixel,
261262
loss_weight_adversarial=lossweightadv,
262263
loss_weight_discriminator=lossweightdisc,

0 commit comments

Comments
 (0)