Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8bdecbe

Browse files
authored
Merge pull request #646 from rsepassi/push
v1.5.5
2 parents af82068 + 688f4d5 commit 8bdecbe

32 files changed

+883
-280
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ matrix:
1616
- python: "3.6"
1717
env: TF_VERSION="1.4.*"
1818
- python: "3.6"
19-
env: TF_VERSION="1.6.*"
19+
env: TF_VERSION="1.5.*"
2020
before_install:
2121
- echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
2222
- curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

.github/ISSUE_TEMPLATE.md renamed to ISSUE_TEMPLATE.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
### *TensorFlow* and *tensor2tensor* versions
88

9-
<!-- **Note** Run `pip list | grep tensor` to include TensorFlow and tensor2tensor versions -->
9+
<!-- **Note** Run `pip freeze | grep tensor` to get versions -->
1010

1111
>
1212
@@ -16,7 +16,7 @@
1616
1717
### In case of bug report: Error log
1818

19-
<!-- Please use code markdown to format output messages. -->
19+
<!-- Please use code markdown (```) to format output messages. -->
2020
<!-- See https://help.github.com/articles/creating-and-highlighting-code-blocks/ -->
2121

2222
>

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
1515
of deep learning models and datasets designed to make deep learning more
1616
accessible and [accelerate ML
1717
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
18-
is actively used and maintained by researchers and engineers within the
18+
T2T is actively used and maintained by researchers and engineers within the
1919
[Google Brain team](https://research.google.com/teams/brain/) and a community
2020
of users. We're eager to collaborate with you too, so feel free to
2121
[open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)

docs/cloud_tpu.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ See the official tutorial for [running Transfomer
1818
on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer)
1919
for some examples and try out your own problems.
2020

21+
Image Transformer:
22+
* `imagetransformer` with `imagetransformer_base_tpu` (or
23+
`imagetransformer_tiny_tpu`)
24+
* `img2img_transformer` with `img2img_transformer_base_tpu` (or
25+
`img2img_transformer_tiny_tpu`)
26+
27+
You can run the `ImageTransformer` model on problems like unconditional or
28+
conditional Image generation and `Img2ImgTransformer` model on Super Resolution.
29+
We run on datasets like CelebA, CIFAR and ImageNet but they should work with any
30+
other image dataset.
31+
2132
Residual networks:
2233
* `resnet` with `resnet_50` (or `resnet_18` or `resnet_34`)
2334
* `revnet` with `revnet_104` (or `revnet_38_cifar`)

docs/walkthrough.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
1515
of deep learning models and datasets designed to make deep learning more
1616
accessible and [accelerate ML
1717
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
18-
is actively used and maintained by researchers and engineers within the
18+
T2T is actively used and maintained by researchers and engineers within the
1919
[Google Brain team](https://research.google.com/teams/brain/) and a community
2020
of users. We're eager to collaborate with you too, so feel free to
2121
[open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
@@ -154,7 +154,7 @@ For all translation problems, we suggest to try the Transformer model:
154154
this should reach a BLEU score of about 28 on the English-German data-set,
155155
which is close to state-of-the art. If training on a single GPU, try the
156156
`--hparams_set=transformer_base_single_gpu` setting. For very good results
157-
or larger data-sets (e.g., for English-French)m, try the big model
157+
or larger data-sets (e.g., for English-French), try the big model
158158
with `--hparams_set=transformer_big`.
159159

160160
## Basics

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.5.4',
8+
version='1.5.5',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='[email protected]',

tensor2tensor/data_generators/cifar.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def preprocess_example(self, example, mode, unused_hparams):
124124
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
125125
if mode == tf.estimator.ModeKeys.TRAIN:
126126
image = image_utils.cifar_image_augmentation(image)
127-
image = tf.image.per_image_standardization(image)
127+
if not self._was_reversed:
128+
image = tf.image.per_image_standardization(image)
128129
example["inputs"] = image
129130
return example
130131

@@ -151,7 +152,8 @@ class ImageCifar10Plain(ImageCifar10):
151152
def preprocess_example(self, example, mode, unused_hparams):
152153
image = example["inputs"]
153154
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
154-
image = tf.image.per_image_standardization(image)
155+
if not self._was_reversed:
156+
image = tf.image.per_image_standardization(image)
155157
example["inputs"] = image
156158
return example
157159

@@ -179,7 +181,8 @@ def dataset_filename(self):
179181
def preprocess_example(self, example, mode, unused_hparams):
180182
image = example["inputs"]
181183
image = image_utils.resize_by_area(image, 8)
182-
image = tf.image.per_image_standardization(image)
184+
if not self._was_reversed:
185+
image = tf.image.per_image_standardization(image)
183186
example["inputs"] = image
184187
return example
185188

@@ -192,7 +195,6 @@ def dataset_filename(self):
192195
return "image_cifar10_plain" # Reuse CIFAR-10 plain data.
193196

194197
def preprocess_example(self, example, unused_mode, unused_hparams):
195-
196198
inputs = example["inputs"]
197199
# For Img2Img resize input and output images as desired.
198200
example["inputs"] = image_utils.resize_by_area(inputs, 8)
@@ -330,7 +332,8 @@ def preprocess_example(self, example, mode, unused_hparams):
330332
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
331333
if mode == tf.estimator.ModeKeys.TRAIN:
332334
image = image_utils.cifar_image_augmentation(image)
333-
image = tf.image.per_image_standardization(image)
335+
if not self._was_reversed:
336+
image = tf.image.per_image_standardization(image)
334337
example["inputs"] = image
335338
return example
336339

@@ -357,7 +360,8 @@ class ImageCifar100Plain(ImageCifar100):
357360
def preprocess_example(self, example, mode, unused_hparams):
358361
image = example["inputs"]
359362
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
360-
image = tf.image.per_image_standardization(image)
363+
if not self._was_reversed:
364+
image = tf.image.per_image_standardization(image)
361365
example["inputs"] = image
362366
return example
363367

@@ -385,7 +389,8 @@ def dataset_filename(self):
385389
def preprocess_example(self, example, mode, unused_hparams):
386390
image = example["inputs"]
387391
image = image_utils.resize_by_area(image, 8)
388-
image = tf.image.per_image_standardization(image)
392+
if not self._was_reversed:
393+
image = tf.image.per_image_standardization(image)
389394
example["inputs"] = image
390395
return example
391396

@@ -398,7 +403,6 @@ def dataset_filename(self):
398403
return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
399404

400405
def preprocess_example(self, example, unused_mode, unused_hparams):
401-
402406
inputs = example["inputs"]
403407
# For Img2Img resize input and output images as desired.
404408
example["inputs"] = image_utils.resize_by_area(inputs, 8)

tensor2tensor/data_generators/gym.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import tensorflow as tf
3636

3737

38+
39+
3840
flags = tf.flags
3941
FLAGS = flags.FLAGS
4042

@@ -157,7 +159,6 @@ def num_steps(self):
157159
return 5000
158160

159161

160-
161162
@registry.register_problem
162163
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
163164
"""Pong game, loaded actions."""
@@ -197,7 +198,7 @@ def generator(self, data_dir, tmp_dir):
197198
model_saver.restore(sess, FLAGS.model_path)
198199
for item in super(GymPongTrajectoriesFromPolicy,
199200
self).generator(data_dir, tmp_dir):
200-
yield item
201+
yield item
201202

202203
# TODO(blazej0): For training of atari agents wrappers are usually used.
203204
# Below we have a hacky solution which is a workaround to be used together

tensor2tensor/data_generators/image_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensor2tensor.data_generators import generator_utils
2727
from tensor2tensor.data_generators import problem
2828
from tensor2tensor.data_generators import text_encoder
29+
from tensor2tensor.utils import metrics
2930
from tensor2tensor.utils import registry
3031

3132
import tensorflow as tf
@@ -64,9 +65,19 @@ def example_reading_spec(self, label_repr=None):
6465
return data_fields, data_items_to_decoders
6566

6667
def preprocess_example(self, example, mode, hparams):
67-
example["inputs"] = tf.image.per_image_standardization(example["inputs"])
68+
if not self._was_reversed:
69+
example["inputs"] = tf.image.per_image_standardization(example["inputs"])
6870
return example
6971

72+
def eval_metrics(self):
73+
eval_metrics = [
74+
metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5,
75+
metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
76+
]
77+
if self._was_reversed:
78+
eval_metrics += [metrics.Metrics.IMAGE_SUMMARY]
79+
return eval_metrics
80+
7081

7182
class Image2ClassProblem(ImageProblem):
7283
"""Base class for image classification problems."""

tensor2tensor/data_generators/imagenet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ def distorted_bounding_box_crop(image,
334334
Returns:
335335
(cropped image `Tensor`, distorted bbox `Tensor`).
336336
"""
337-
with tf.name_scope(scope, default_name="distorted_bounding_box_crop", values=[image, bbox]):
337+
with tf.name_scope(scope, default_name="distorted_bounding_box_crop",
338+
values=[image, bbox]):
338339
# Each bounding box has shape [1, num_boxes, box coords] and
339340
# the coordinates are ordered [ymin, xmin, ymax, xmax].
340341

0 commit comments

Comments
 (0)