
Commit aac632f

Author: Ryan Sepassi
Batch norm fix, --use_tpu, and Resnet50 model
PiperOrigin-RevId: 175353643
Parent: fa46070

14 files changed: +441 -60 lines

tensor2tensor/layers/common_attention.py

Lines changed: 13 additions & 5 deletions
@@ -1323,12 +1323,20 @@ def masked_local_attention_1d(
   with tf.variable_scope(name, default_name="local_attention_1d",
                          values=[q, k, v]):
     v_shape = v.get_shape()
-    batch = tf.shape(q)[0]
-    heads = tf.shape(q)[1]
-    length = tf.shape(q)[2]
+    batch = common_layers.shape_dim(q, 0)
+    heads = common_layers.shape_dim(q, 1)
+    length = common_layers.shape_dim(q, 2)
+    if isinstance(block_length, tf.Tensor):
+      const = tf.contrib.util.constant_value(block_length)
+      if const is not None:
+        block_length = int(const)
+
     # If (length < 2 * block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length * 2),
-                            length, block_length)
+    if isinstance(length, int) and isinstance(block_length, int):
+      block_length = length if length < block_length * 2 else block_length
+    else:
+      block_length = tf.where(tf.less(length, block_length * 2),
+                              length, block_length)
     depth_k = tf.shape(k)[3]
     depth_v = tf.shape(v)[3]
     original_length = length
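
Editor's note: this change swaps dynamic tf.shape() lookups for common_layers.shape_dim, which presumably returns a Python int when the dimension is statically known and a tensor only as a fallback; that is what lets the block_length comparison fold into plain Python instead of a tf.where op, keeping shapes static for TPU. A minimal sketch of such a helper, assuming TF 1.x APIs (the real implementation lives in common_layers.py and is not part of this diff):

import tensorflow as tf

def shape_dim(x, dim):
  """Return dimension `dim` of `x`: a Python int if static, else a tensor."""
  static = x.get_shape().dims
  if static is not None and static[dim].value is not None:
    return static[dim].value  # statically known: plain int
  return tf.shape(x)[dim]     # fall back to a dynamic shape op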

tensor2tensor/layers/modalities.py

Lines changed: 2 additions & 1 deletion
@@ -414,7 +414,8 @@ def bottom(self, x):

   def targets_bottom(self, x):
     with tf.variable_scope(self.name):
-      return tf.zeros([tf.shape(x)[0], 1, 1, self._body_input_depth])
+      return tf.zeros(
+          [common_layers.shape_dim(x, 0), 1, 1, self._body_input_depth])

   def top(self, body_output, _):
     """Transform inputs from model space to target space.

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
 from tensor2tensor.models import lstm
 from tensor2tensor.models import multimodel
 from tensor2tensor.models import neural_gpu
+from tensor2tensor.models import resnet
 from tensor2tensor.models import shake_shake
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
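
Editor's note: the import matters for its side effect — loading tensor2tensor.models.resnet runs the @registry.register_model decorator on Resnet50, which is what makes the model reachable by name. A sketch of the lookup, assuming the registry's usual snake_case naming (the key "resnet50" is an inference, not shown in this diff):

from tensor2tensor import models  # noqa: F401 -- triggers registrations
from tensor2tensor.utils import registry

model_cls = registry.model("resnet50")  # assumed key for class Resnet50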

tensor2tensor/models/resnet.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@ (new file)
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Resnets."""
# Copied from cloud_tpu/models/resnet_garden and modified

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf

# TODO(rsepassi): make hparams
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5


def bottleneck_block(inputs, filters, is_training, projection_shortcut,
                     strides, data_format):
  """Bottleneck block variant for residual networks with BN before convolutions.

  Args:
    inputs: A tensor of size [batch, channels, height, width].
    filters: The number of filters for the first two convolutions. Note that
      the third and final convolution will use 4 times as many filters.
    is_training: A Boolean for whether the model is in training or inference
      mode. Needed for batch normalization.
    projection_shortcut: The function to use for projection shortcuts
      (typically a 1x1 convolution when downsampling the input).
    strides: The block's stride. If greater than 1, this block will ultimately
      downsample the input.
    data_format: channels_{first, last}

  Returns:
    The output tensor of the block.
  """
  shortcut = inputs
  out = inputs
  out = batch_norm_relu(out, is_training, data_format)

  # The projection shortcut should come after the first batch norm and ReLU
  # since it performs a 1x1 convolution.
  if projection_shortcut is not None:
    shortcut = projection_shortcut(out)

  do_bn_relus = [False, True, True]
  kernel_sizes = [1, 3, 1]
  layer_strides = [1, strides, 1]
  filter_sizes = [filters, filters, 4 * filters]

  for do_bn_relu, kernel_size, layer_stride, filter_size in zip(
      do_bn_relus, kernel_sizes, layer_strides, filter_sizes):
    if do_bn_relu:
      out = batch_norm_relu(out, is_training, data_format)
    out = conv2d_fixed_padding(
        inputs=out,
        filters=filter_size,
        kernel_size=kernel_size,
        strides=layer_stride,
        data_format=data_format)

  return out + shortcut


def batch_norm_relu(inputs, is_training, data_format):
  """Performs a batch normalization followed by a ReLU."""
  # We set fused=True for a significant performance boost.
  out = tf.layers.batch_normalization(
      inputs=inputs,
      axis=1 if data_format == "channels_first" else 3,
      momentum=_BATCH_NORM_DECAY,
      epsilon=_BATCH_NORM_EPSILON,
      center=True,
      scale=True,
      training=is_training,
      fused=True)
  out = tf.nn.relu(out)
  return out


def block_layer(inputs, filters, block_fn, blocks, strides, is_training,
                data_format, name):
  """Creates one layer of blocks for the ResNet model.

  Args:
    inputs: A tensor of size [batch, channels, height, width].
    filters: The number of filters for the first convolution of the layer.
    block_fn: The block to use within the model, either `building_block` or
      `bottleneck_block`.
    blocks: The number of blocks contained in the layer.
    strides: The stride to use for the first convolution of the layer. If
      greater than 1, this layer will ultimately downsample the input.
    is_training: Either True or False, whether we are currently training the
      model. Needed for batch norm.
    data_format: channels_{first, last}
    name: A string name for the tensor output of the block layer.

  Returns:
    The output tensor of the block layer.
  """
  # Bottleneck blocks end with 4x the number of filters as they start with
  filters_out = 4 * filters if block_fn is bottleneck_block else filters

  def projection_shortcut(inputs):
    return conv2d_fixed_padding(
        inputs=inputs,
        filters=filters_out,
        kernel_size=1,
        strides=strides,
        data_format=data_format)

  # Only the first block per block_layer uses projection_shortcut and strides
  inputs = block_fn(inputs, filters, is_training, projection_shortcut, strides,
                    data_format)

  for _ in range(1, blocks):
    inputs = block_fn(inputs, filters, is_training, None, 1, data_format)

  return tf.identity(inputs, name)


def fixed_padding(inputs, kernel_size, data_format):
  """Pads the input along the spatial dimensions independently of input size.

  Args:
    inputs: A 4D tensor laid out according to data_format
    kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
      Should be a positive integer.
    data_format: channels_{first, last}

  Returns:
    A tensor of size [batch, channels, height_out, width_out] with the
    input either intact (if kernel_size == 1) or padded (if kernel_size > 1).
  """
  pad_total = kernel_size - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  spatial_pads = [[pad_beg, pad_end], [pad_beg, pad_end]]
  if data_format == "channels_first":
    pads = [[0, 0], [0, 0]] + spatial_pads
  else:
    assert data_format == "channels_last"
    pads = [[0, 0]] + spatial_pads + [[0, 0]]
  padded_inputs = tf.pad(inputs, pads)
  return padded_inputs


def conv2d_fixed_padding(**kwargs):
  """conv2d with fixed_padding, based only on kernel_size."""
  strides = kwargs["strides"]
  if strides > 1:
    kwargs["inputs"] = fixed_padding(kwargs["inputs"], kwargs["kernel_size"],
                                     kwargs["data_format"])

  defaults = {
      "padding": ("SAME" if strides == 1 else "VALID"),
      "use_bias": False,
      "kernel_initializer": tf.variance_scaling_initializer(),
  }
  defaults.update(kwargs)

  return tf.layers.conv2d(**defaults)


def resnet50(inputs, hparams):
  """Resnet50."""
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  block_fn = bottleneck_block

  out = inputs
  data_format = "channels_first" if hparams.use_nchw else "channels_last"
  if hparams.use_nchw:
    # Convert from channels_last (NHWC) to channels_first (NCHW). This
    # provides a large performance boost on GPU.
    out = tf.transpose(inputs, [0, 3, 1, 2])

  out = conv2d_fixed_padding(
      inputs=out, filters=64, kernel_size=7, strides=2, data_format=data_format)
  out = tf.identity(out, "initial_conv")
  out = tf.layers.max_pooling2d(
      inputs=out,
      pool_size=3,
      strides=2,
      padding="SAME",
      data_format=data_format)
  out = tf.identity(out, "initial_max_pool")

  for i, (num_filters, stride, block_size) in enumerate(
      zip(hparams.num_filters, hparams.strides, hparams.layer_sizes)):
    out = block_layer(
        inputs=out,
        filters=num_filters,
        block_fn=block_fn,
        blocks=block_size,
        strides=stride,
        is_training=is_training,
        data_format=data_format,
        name="block_layer_%d" % i)

  out = batch_norm_relu(out, is_training, data_format)
  out = tf.layers.average_pooling2d(
      inputs=out,
      pool_size=7,
      strides=1,
      padding="VALID",
      data_format=data_format)
  out = tf.identity(out, "final_avg_pool")

  if hparams.use_nchw:
    # Back to NHWC
    out = tf.transpose(out, [0, 2, 3, 1])
  return out


@registry.register_model
class Resnet50(t2t_model.T2TModel):

  def model_fn_body(self, features):
    return resnet50(features["inputs"], self.hparams)


@registry.register_hparams
def resnet_base():
  """Set of hyperparameters."""
  hparams = common_hparams.basic_params1()
  hparams.add_hparam("layer_sizes", [3, 4, 6, 3])
  hparams.add_hparam("use_nchw", True)
  hparams.add_hparam("num_filters", [64, 128, 256, 512])
  hparams.add_hparam("strides", [1, 2, 2, 2])
  hparams.tpu_batch_size_per_shard = 48
  return hparams
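
Editor's note: fixed_padding exists so the output spatial size depends only on the input size and stride, never on how "SAME" padding happens to split for a particular input. Worked through for the initial 7x7/stride-2 conv: pad_total = 6, so 3 pixels on each side; a 224x224 input becomes 230x230 and the VALID conv yields (230 - 7) // 2 + 1 = 112. A standalone TF 1.x sketch (not from the commit):

import tensorflow as tf

images = tf.zeros([1, 224, 224, 3])  # NHWC
# fixed_padding for kernel_size=7: 3 on each spatial side
padded = tf.pad(images, [[0, 0], [3, 3], [3, 3], [0, 0]])
out = tf.layers.conv2d(padded, filters=64, kernel_size=7, strides=2,
                       padding="VALID", use_bias=False)
print(out.get_shape())  # (1, 112, 112, 64)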

tensor2tensor/models/resnet_test.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@ (new file)
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Resnet tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import resnet
from tensor2tensor.utils import registry

import tensorflow as tf


def resnet_tiny_cpu():
  hparams = resnet.resnet_base()
  hparams.layer_sizes = [2, 2, 2, 2]
  hparams.num_filters = [10, 20, 30, 40]
  hparams.use_nchw = False
  return hparams


class ResnetTest(tf.test.TestCase):

  def _testResnet(self, img_size, output_size):
    vocab_size = 9
    batch_size = 2
    x = np.random.random_integers(
        0, high=255, size=(batch_size, img_size, img_size, 3))
    y = np.random.random_integers(
        1, high=vocab_size - 1, size=(batch_size, 1, 1, 1))
    hparams = resnet_tiny_cpu()
    p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size)
    p_hparams.input_modality["inputs"] = (registry.Modalities.IMAGE, None)
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
      }
      model = resnet.Resnet50(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
      sharded_logits, _ = model.model_fn(features)
      logits = tf.concat(sharded_logits, 0)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
      self.assertEqual(res.shape, (batch_size,) + output_size + (1, vocab_size))

  def testResnetLarge(self):
    self._testResnet(img_size=299, output_size=(4, 4))


if __name__ == "__main__":
  tf.test.main()
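
Editor's note: the expected (4, 4) output for a 299x299 input follows from the network's spatial reductions; a back-of-the-envelope trace (my arithmetic, not from the commit):

# 299x299 input through resnet50's spatial reductions:
#   7x7 conv, stride 2, fixed padding : (299 + 6 - 7) // 2 + 1 = 150
#   3x3 max pool, stride 2, SAME      : ceil(150 / 2)          = 75
#   block layers, strides [1, 2, 2, 2]: 75 -> 75 -> 38 -> 19 -> 10
#   7x7 average pool, stride 1, VALID : 10 - 7 + 1             = 4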

tensor2tensor/models/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -1084,7 +1084,7 @@ def update_hparams_for_tpu(hparams):
   # Inputs
   # Each example in the batch will be of (padded) length hparams.max_length
   hparams.max_length = 64
-  hparams.tpu_batch_size_per_shard = 16
+  hparams.tpu_batch_size_per_shard = 20


 @registry.register_hparams

tensor2tensor/tpu/tpu_trainer.py

Lines changed: 9 additions & 3 deletions
@@ -41,6 +41,7 @@
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 1000,
                      "Number of iterations in a TPU training loop.")
+flags.DEFINE_bool("use_tpu", True, "Whether to use TPU.")

 # To maintain compatibility with some internal libs, we guard against these flag
 # definitions possibly erroring. Apologies for the ugliness.
@@ -68,9 +69,14 @@ def create_hparams():


 def create_experiment_fn():
-  return lib.make_experiment_fn(FLAGS.model, get_problem_name(), FLAGS.data_dir,
-                                FLAGS.train_steps, FLAGS.eval_steps,
-                                FLAGS.local_eval_frequency)
+  return lib.make_experiment_fn(
+      FLAGS.model,
+      get_problem_name(),
+      FLAGS.data_dir,
+      FLAGS.train_steps,
+      FLAGS.eval_steps,
+      FLAGS.local_eval_frequency,
+      use_tpu=FLAGS.use_tpu)


 def create_run_config():
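
Editor's note: with the new flag, TPU use can be toggled per run. A hypothetical invocation using only flags visible in this diff (the problem-selection flag is omitted and the paths are placeholders):

python -m tensor2tensor.tpu.tpu_trainer \
  --model=resnet50 \
  --data_dir=/tmp/t2t_data \
  --train_steps=1000 \
  --eval_steps=10 \
  --tpu_num_shards=8 \
  --use_tpu=False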
