Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 4084c5c

Browse files
Niki ParmarRyan Sepassi
authored andcommitted
Add modality for images that compresses pixels and can be used for generation tasks.
PiperOrigin-RevId: 175313670
1 parent eb5652f commit 4084c5c

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ def standardize_images(x):
151151
return x
152152

153153

154+
def convert_rgb_to_real(x):
155+
"""Conversion of pixel values to real numbers."""
156+
with tf.name_scope("rgb_to_real", [x]):
157+
x = tf.to_float(x)
158+
# Use the formula (value/128) - 1 to convert each channel value into a
159+
# real number in the range -1 to 1.
160+
x = (x /128) - 1
161+
return x
162+
163+
154164
def image_augmentation(images, do_colors=False):
155165
"""Image augmentation: cropping, flipping, and color transforms."""
156166
images = tf.random_crop(images, [299, 299, 3])

tensor2tensor/layers/modalities.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,76 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
229229
top_out, targets, weights_fn=weights_fn)
230230

231231

232+
@registry.register_image_modality("image_identity_compress")
233+
class ImageIdentityCompressModality(modality.Modality):
234+
"""Modality for images used in generation."""
235+
236+
@property
237+
def top_dimensionality(self):
238+
return 256
239+
240+
def bottom_compress(self, inputs, name="bottom"):
241+
"""Transform input from data space to model space.
242+
243+
Perform conversion of RGB pixel values to a real number and combine values
244+
for each pixel to form representation of image_length x image_length dims.
245+
246+
Args:
247+
inputs: A Tensor with shape [batch, ...]
248+
name: string, scope.
249+
Returns:
250+
body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
251+
"""
252+
with tf.variable_scope(name):
253+
inputs = common_layers.convert_rgb_to_real(inputs)
254+
ishape = tf.shape(inputs)
255+
inputs = tf.reshape(inputs, [-1, ishape[1], ishape[2]*ishape[3], 1])
256+
inputs.set_shape([None, None, None, 1])
257+
# We compress RGB intensities for each pixel using a conv.
258+
x = common_layers.conv_block(
259+
inputs,
260+
self._body_input_depth, [((1, 1), (1, 3))],
261+
first_relu=False,
262+
padding="VALID",
263+
strides=(1, 3),
264+
force2d=True,
265+
name="conv_input")
266+
return x
267+
268+
def bottom(self, inputs):
269+
return self.bottom_compress(inputs, "input_bottom")
270+
271+
def targets_bottom(self, inputs):
272+
return self.bottom_compress(inputs, "output_bottom")
273+
274+
def top(self, body_output, _):
275+
with tf.variable_scope(self.name):
276+
hidden_dim = self._model_hparams.hidden_size
277+
img_len = self._model_hparams.img_len
278+
channels = self._model_hparams.num_channels
279+
batch = tf.shape(body_output)[0]
280+
x = common_layers.conv(
281+
body_output,
282+
hidden_dim*channels, (1, 1),
283+
padding="VALID",
284+
activation=tf.nn.relu,
285+
name="decompress_conv")
286+
x = tf.reshape(x, [batch, img_len, img_len*channels, hidden_dim])
287+
x.set_shape([None, None, None, hidden_dim])
288+
x = common_layers.conv(x,
289+
self.top_dimensionality,
290+
(1, 1), name="output_conv")
291+
x = tf.reshape(x, [-1, img_len, img_len,
292+
channels, self.top_dimensionality])
293+
return x
294+
295+
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
296+
# Call the default implementation, but weight 1.0 on 0s by default.
297+
# (Since we're processing images and so have no padding and some pixel 0s.)
298+
return super(ImageIdentityCompressModality, self).loss(
299+
top_out, targets, weights_fn=weights_fn)
300+
301+
232302
@registry.register_audio_modality("default")
233303
class AudioModality(modality.Modality):
234304
"""Performs strided conv compressions for audio data."""

0 commit comments

Comments
 (0)