forked from HasnainRaz/Fast-SRGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
239 lines (199 loc) · 9.91 KB
/
model.py
File metadata and controls
239 lines (199 loc) · 9.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from tensorflow import keras
import tensorflow as tf
class FastSRGAN:
    """SRGAN for fast super resolution.

    Bundles a MobileNet-style generator (4x upsampling), an SRGAN-style
    patch discriminator, a frozen VGG19 feature extractor used by the
    content loss, and one Adam optimizer per network with a step-wise
    exponentially decaying learning rate.
    """

    def __init__(self, args):
        """Initializes the Mobile SRGAN class.

        Args:
            args: CLI arguments that dictate how to build the model. Must
                provide ``hr_size`` (side length of the square high-resolution
                images) and ``lr`` (base learning rate of the generator).
        """
        # Square images. The generator upsamples by 4 (two 2x stages), so the
        # low-resolution side is a quarter of the high-resolution one.
        # NOTE: if hr_size is not a multiple of 4, the generator output
        # (4 * (hr_size // 4) pixels per side) will not match ``hr_shape``.
        self.hr_height = args.hr_size
        self.hr_width = args.hr_size
        self.lr_height = self.hr_height // 4  # Low resolution height
        self.lr_width = self.hr_width // 4  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, 3)
        self.hr_shape = (self.hr_height, self.hr_width, 3)
        self.iterations = 0

        # Number of inverted residual blocks in the mobilenet generator
        self.n_residual_blocks = 6

        # Learning-rate schedules: the rate is multiplied by 0.1 every 100k
        # steps; staircase=True makes the decay step-wise, not continuous.
        self.gen_schedule = keras.optimizers.schedules.ExponentialDecay(
            args.lr,
            decay_steps=100000,
            decay_rate=0.1,
            staircase=True,
        )
        self.disc_schedule = keras.optimizers.schedules.ExponentialDecay(
            args.lr * 5,  # TTUR - Two Time Scale Updates: D uses a 5x higher rate.
            decay_steps=100000,
            decay_rate=0.1,
            staircase=True,
        )
        self.gen_optimizer = keras.optimizers.Adam(learning_rate=self.gen_schedule)
        self.disc_optimizer = keras.optimizers.Adam(learning_rate=self.disc_schedule)

        # Pre-trained VGG19 used to extract image features from the real and
        # the generated high resolution images; the MSE between them is the
        # content loss. It is frozen and never trained.
        self.vgg = self.build_vgg()
        self.vgg.trainable = False

        # Output shape of D (PatchGAN). D applies four stride-2 convolutions
        # with 'same' padding, each producing ceil(size / 2), so the patch side
        # is ceil(hr_height / 16) -- a ceiling, not a floor, for heights that
        # are not a multiple of 16.
        downsample_factor = 2 ** 4
        patch = -(-self.hr_height // downsample_factor)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32  # Realtime Image Enhancement GAN Galteri et al.
        self.df = 32

        # Build the discriminator (not compiled here).
        self.discriminator = self.build_discriminator()

        # Build the generator (not compiled here).
        self.generator = self.build_generator()

    @tf.function
    def content_loss(self, hr, sr):
        """Perceptual (VGG feature) loss between real and generated images.

        Both batches are mapped from [-1, 1] to [0, 255], passed through the
        VGG19 preprocessing and the frozen feature extractor, and compared
        with a mean squared error.

        Args:
            hr: Ground-truth high-resolution images; assumed to be scaled to
                [-1, 1] like the generator's tanh output (TODO confirm in the
                data pipeline).
            sr: Super-resolved images produced by the generator.

        Returns:
            The mean squared error between the rescaled VGG feature maps.
        """
        sr = keras.applications.vgg19.preprocess_input(((sr + 1.0) * 255) / 2.0)
        hr = keras.applications.vgg19.preprocess_input(((hr + 1.0) * 255) / 2.0)
        # Shrink feature magnitudes before the MSE; presumably the 1/12.75
        # rescaling used in the SRGAN paper -- confirm.
        sr_features = self.vgg(sr) / 12.75
        hr_features = self.vgg(hr) / 12.75
        return tf.keras.losses.MeanSquaredError()(hr_features, sr_features)

    def build_vgg(self):
        """Builds a frozen, ImageNet-pretrained VGG19 feature extractor.

        The classifier head is dropped (``include_top=False``) and the model
        outputs the activation of ``block5_conv4``, the last convolution of
        the fifth block.

        Returns:
            A ``keras.models.Model`` mapping an ``hr_shape`` image batch to
            ``block5_conv4`` features, with every layer non-trainable.
        """
        vgg = keras.applications.VGG19(weights="imagenet", input_shape=self.hr_shape, include_top=False)
        vgg.trainable = False
        for layer in vgg.layers:
            layer.trainable = False
        # Truncate the network at block 5's last convolution (no compile needed).
        model = keras.models.Model(inputs=vgg.input, outputs=vgg.get_layer("block5_conv4").output)
        return model

    def build_generator(self):
        """Build the generator that will do the Super Resolution task.

        Based on the Mobilenet design. Idea from Galteri et al.

        Layout: conv-BN-PReLU stem, ``n_residual_blocks`` inverted residual
        blocks, conv-BN with a long skip connection back to the stem, two 2x
        upsampling stages, and a 3-channel tanh output convolution.

        Returns:
            A ``keras.models.Model`` mapping ``lr_shape`` inputs to images
            4x larger in each spatial dimension, with values in [-1, 1].
        """

        def _make_divisible(v, divisor, min_value=None):
            # Round v to the nearest multiple of divisor (at least min_value).
            if min_value is None:
                min_value = divisor
            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
            # Make sure that round down does not go down by more than 10%.
            if new_v < 0.9 * v:
                new_v += divisor
            return new_v

        def residual_block(inputs, filters, block_id, expansion=6, stride=1, alpha=1.0):
            """Inverted residual block that uses depth wise convolutions for parameter efficiency.

            Args:
                inputs: The input feature map.
                filters: Number of output filters of the final 1x1 projection,
                    before ``alpha`` scaling and rounding to a multiple of 8.
                block_id: Integer id of the block, used in layer names. Block 0
                    skips the 1x1 expansion convolution.
                expansion: Channel expansion factor of the 1x1 expansion conv.
                stride: The stride of the depthwise convolution.
                alpha: Width multiplier applied to ``filters``.

            Returns:
                x: The output of the inverted residual block. An ``Add`` with
                ``inputs`` is applied when channel count and stride allow it.
            """
            channel_axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
            # Static channel count. ``.shape`` works on both Keras 2 and Keras 3
            # symbolic tensors, unlike the legacy ``keras.backend.int_shape``.
            in_channels = inputs.shape[channel_axis]
            pointwise_conv_filters = int(filters * alpha)
            pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
            x = inputs
            prefix = 'block_{}_'.format(block_id)

            if block_id:
                # Expand
                x = keras.layers.Conv2D(expansion * in_channels,
                                        kernel_size=1,
                                        padding='same',
                                        use_bias=True,
                                        activation=None,
                                        name=prefix + 'expand')(x)
                x = keras.layers.BatchNormalization(axis=channel_axis,
                                                    epsilon=1e-3,
                                                    momentum=0.999,
                                                    name=prefix + 'expand_BN')(x)
                x = keras.layers.Activation('relu', name=prefix + 'expand_relu')(x)
            else:
                prefix = 'expanded_conv_'

            # Depthwise. Strides > 1 use 'valid' padding without explicit
            # zero-padding; every call site in this class uses stride 1.
            x = keras.layers.DepthwiseConv2D(kernel_size=3,
                                             strides=stride,
                                             activation=None,
                                             use_bias=True,
                                             padding='same' if stride == 1 else 'valid',
                                             name=prefix + 'depthwise')(x)
            x = keras.layers.BatchNormalization(axis=channel_axis,
                                                epsilon=1e-3,
                                                momentum=0.999,
                                                name=prefix + 'depthwise_BN')(x)

            x = keras.layers.Activation('relu', name=prefix + 'depthwise_relu')(x)

            # Project (linear bottleneck: no activation after the BN).
            x = keras.layers.Conv2D(pointwise_filters,
                                    kernel_size=1,
                                    padding='same',
                                    use_bias=True,
                                    activation=None,
                                    name=prefix + 'project')(x)
            x = keras.layers.BatchNormalization(axis=channel_axis,
                                                epsilon=1e-3,
                                                momentum=0.999,
                                                name=prefix + 'project_BN')(x)

            if in_channels == pointwise_filters and stride == 1:
                return keras.layers.Add(name=prefix + 'add')([inputs, x])
            return x

        def deconv2d(layer_input):
            """Upsampling layer to increase height and width of the input.

            Uses 2x bilinear ``UpSampling2D`` followed by a 3x3 convolution
            with ``self.gf`` filters and a PReLU activation.

            Args:
                layer_input: The input tensor to upsample.

            Returns:
                u: Upsampled input by a factor of 2.
            """
            u = keras.layers.UpSampling2D(size=2, interpolation='bilinear')(layer_input)
            u = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(u)
            u = keras.layers.PReLU(shared_axes=[1, 2])(u)
            return u

        # Low resolution image input
        img_lr = keras.Input(shape=self.lr_shape)

        # Pre-residual block
        c1 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(img_lr)
        c1 = keras.layers.BatchNormalization()(c1)
        c1 = keras.layers.PReLU(shared_axes=[1, 2])(c1)

        # Propagate through residual blocks
        r = residual_block(c1, self.gf, 0)
        for idx in range(1, self.n_residual_blocks):
            r = residual_block(r, self.gf, idx)

        # Post-residual block, with a long skip connection to the stem output.
        c2 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
        c2 = keras.layers.BatchNormalization()(c2)
        c2 = keras.layers.Add()([c2, c1])

        # Upsampling: two 2x stages give the overall 4x scale.
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate high resolution output
        gen_hr = keras.layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh')(u2)

        return keras.models.Model(img_lr, gen_hr)

    def build_discriminator(self):
        """Builds a discriminator network based on the SRGAN design.

        Eight conv blocks (LeakyReLU, batch norm on all but the first), four of
        them with stride 2, followed by a 1x1 sigmoid convolution. The output
        is a per-patch probability map of shape ``self.disc_patch``.

        Returns:
            A ``keras.models.Model`` mapping ``hr_shape`` images to the patch map.
        """

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer block.

            Args:
                layer_input: Input feature map for the convolutional block.
                filters: Number of filters in the convolution.
                strides: The stride of the convolution.
                bn: Whether to use batch norm or not.

            Returns:
                The block's output feature map.
            """
            d = keras.layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            if bn:
                d = keras.layers.BatchNormalization(momentum=0.8)(d)
            d = keras.layers.LeakyReLU(alpha=0.2)(d)
            return d

        # Input img
        d0 = keras.layers.Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df)
        d4 = d_block(d3, self.df, strides=2)
        d5 = d_block(d4, self.df * 2)
        d6 = d_block(d5, self.df * 2, strides=2)
        d7 = d_block(d6, self.df * 2)
        d8 = d_block(d7, self.df * 2, strides=2)

        validity = keras.layers.Conv2D(1, kernel_size=1, strides=1, activation='sigmoid', padding='same')(d8)

        return keras.models.Model(d0, validity)