@@ -229,6 +229,76 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
229
229
top_out , targets , weights_fn = weights_fn )
230
230
231
231
232
+ @registry .register_image_modality ("image_identity_compress" )
233
+ class ImageIdentityCompressModality (modality .Modality ):
234
+ """Modality for images used in generation."""
235
+
236
+ @property
237
+ def top_dimensionality (self ):
238
+ return 256
239
+
240
+ def bottom_compress (self , inputs , name = "bottom" ):
241
+ """Transform input from data space to model space.
242
+
243
+ Perform conversion of RGB pixel values to a real number and combine values
244
+ for each pixel to form representation of image_length x image_length dims.
245
+
246
+ Args:
247
+ inputs: A Tensor with shape [batch, ...]
248
+ name: string, scope.
249
+ Returns:
250
+ body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
251
+ """
252
+ with tf .variable_scope (name ):
253
+ inputs = common_layers .convert_rgb_to_real (inputs )
254
+ ishape = tf .shape (inputs )
255
+ inputs = tf .reshape (inputs , [- 1 , ishape [1 ], ishape [2 ]* ishape [3 ], 1 ])
256
+ inputs .set_shape ([None , None , None , 1 ])
257
+ # We compress RGB intensities for each pixel using a conv.
258
+ x = common_layers .conv_block (
259
+ inputs ,
260
+ self ._body_input_depth , [((1 , 1 ), (1 , 3 ))],
261
+ first_relu = False ,
262
+ padding = "VALID" ,
263
+ strides = (1 , 3 ),
264
+ force2d = True ,
265
+ name = "conv_input" )
266
+ return x
267
+
268
+ def bottom (self , inputs ):
269
+ return self .bottom_compress (inputs , "input_bottom" )
270
+
271
+ def targets_bottom (self , inputs ):
272
+ return self .bottom_compress (inputs , "output_bottom" )
273
+
274
+ def top (self , body_output , _ ):
275
+ with tf .variable_scope (self .name ):
276
+ hidden_dim = self ._model_hparams .hidden_size
277
+ img_len = self ._model_hparams .img_len
278
+ channels = self ._model_hparams .num_channels
279
+ batch = tf .shape (body_output )[0 ]
280
+ x = common_layers .conv (
281
+ body_output ,
282
+ hidden_dim * channels , (1 , 1 ),
283
+ padding = "VALID" ,
284
+ activation = tf .nn .relu ,
285
+ name = "decompress_conv" )
286
+ x = tf .reshape (x , [batch , img_len , img_len * channels , hidden_dim ])
287
+ x .set_shape ([None , None , None , hidden_dim ])
288
+ x = common_layers .conv (x ,
289
+ self .top_dimensionality ,
290
+ (1 , 1 ), name = "output_conv" )
291
+ x = tf .reshape (x , [- 1 , img_len , img_len ,
292
+ channels , self .top_dimensionality ])
293
+ return x
294
+
295
+ def loss (self , top_out , targets , weights_fn = common_layers .weights_all ):
296
+ # Call the default implementation, but weight 1.0 on 0s by default.
297
+ # (Since we're processing images and so have no padding and some pixel 0s.)
298
+ return super (ImageIdentityCompressModality , self ).loss (
299
+ top_out , targets , weights_fn = weights_fn )
300
+
301
+
232
302
@registry .register_audio_modality ("default" )
233
303
class AudioModality (modality .Modality ):
234
304
"""Performs strided conv compressions for audio data."""
0 commit comments