Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit fa46070

Browse files
Lukasz Kaiser and Ryan Sepassi
authored and committed
Get rid of "last_position_only" by adding the corresponding property to Modality.
PiperOrigin-RevId: 175342179
1 parent 4084c5c commit fa46070

File tree

8 files changed

+52
-62
lines changed

8 files changed

+52
-62
lines changed

tensor2tensor/layers/modalities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def name(self):
4949
def top_dimensionality(self):
5050
return self._vocab_size
5151

52+
@property
53+
def top_is_pointwise(self):
54+
return True
55+
5256
def _get_weights(self, hidden_dim=None):
5357
"""Create or get concatenated embedding or softmax variable.
5458

tensor2tensor/models/transformer.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,52 +143,46 @@ def model_fn_body(self, features):
143143
encoder_decoder_attention_bias,
144144
decoder_self_attention_bias, hparams)
145145

146-
def _greedy_infer(self, features, decode_length, last_position_only=True):
146+
def _greedy_infer(self, features, decode_length):
147147
"""Fast version of greedy decoding.
148148
149149
Args:
150150
features: a map of string to `Tensor`
151151
decode_length: an integer. How many additional timesteps to decode.
152-
last_position_only: MUST be true for fast decoding!
153152
154153
Returns:
155154
samples: [batch_size, input_length + decode_length]
156155
logits: Not returned
157156
losses: Not returned
158157
159158
Raises:
160-
ValueError: If last_position_only is False
161159
NotImplementedError: If there are multiple data shards.
162160
"""
163-
decoded_ids, _ = self._fast_decode(
164-
features, decode_length, last_position_only)
161+
decoded_ids, _ = self._fast_decode(features, decode_length)
165162
return decoded_ids, None, None
166163

167164
def _beam_decode(self, features, decode_length, beam_size, top_beams,
168-
last_position_only, alpha):
165+
alpha):
169166
"""Beam search decoding.
170167
171168
Args:
172169
features: a map of string to `Tensor`
173170
decode_length: an integer. How many additional timesteps to decode.
174171
beam_size: number of beams.
175172
top_beams: an integer. How many of the beams to return.
176-
last_position_only: MUST be true for fast decoding!
177173
alpha: Float that controls the length penalty. larger the alpha, stronger
178174
the preference for longer translations.
179175
180176
Returns:
181177
samples: an integer `Tensor`. Top samples from the beam search
182178
"""
183179
decoded_ids, scores = self._fast_decode(
184-
features, decode_length, last_position_only, beam_size, top_beams,
185-
alpha)
180+
features, decode_length, beam_size, top_beams, alpha)
186181
return {"outputs": decoded_ids, "scores": scores}
187182

188183
def _fast_decode(self,
189184
features,
190185
decode_length,
191-
last_position_only=True,
192186
beam_size=1,
193187
top_beams=1,
194188
alpha=1.0):
@@ -200,7 +194,6 @@ def _fast_decode(self,
200194
Args:
201195
features: a map of string to model features.
202196
decode_length: an integer. How many additional timesteps to decode.
203-
last_position_only: MUST be true for fast decoding!
204197
beam_size: number of beams.
205198
top_beams: an integer. How many of the beams to return.
206199
alpha: Float that controls the length penalty. larger the alpha, stronger
@@ -210,11 +203,8 @@ def _fast_decode(self,
210203
samples: an integer `Tensor`. Top samples from the beam search
211204
212205
Raises:
213-
ValueError: If last_position_only is False
214206
NotImplementedError: If there are multiple data shards.
215207
"""
216-
if not last_position_only:
217-
raise ValueError("Fast decoding only deals with the last positions!")
218208
if self._num_datashards != 1:
219209
raise NotImplementedError("Fast decoding only supports a single shard.")
220210
dp = self._data_parallelism

tensor2tensor/models/transformer_adv.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def model_fn_body(self, features):
166166
features["target_space_id"], self._hparams)
167167

168168
def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
169-
last_position_only=False, alpha=0.0):
169+
alpha=0.0):
170170
"""Produce predictions from the model."""
171171
if not features:
172172
features = {}
@@ -184,8 +184,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
184184
initial_output = tf.zeros((batch_size, 2 * length, 1, 1), dtype=tf.int64)
185185

186186
features["targets"] = initial_output
187-
sharded_logits, _ = self.model_fn(
188-
features, False, last_position_only=last_position_only)
187+
sharded_logits, _ = self.model_fn(features, False)
189188
sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
190189
samples = tf.concat(sharded_samples, 0)
191190

@@ -194,8 +193,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
194193
for _ in xrange(how_many_more_steps):
195194
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
196195
features["targets"] = samples
197-
sharded_logits, _ = self.model_fn(
198-
features, False, last_position_only=last_position_only)
196+
sharded_logits, _ = self.model_fn(features, False)
199197
sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
200198
samples = tf.concat(sharded_samples, 0)
201199

tensor2tensor/models/transformer_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN):
5656
"target_space_id": tf.constant(1, dtype=tf.int32),
5757
}
5858

59-
return transformer.Transformer(
60-
hparams, tf.estimator.ModeKeys.PREDICT, p_hparams), features
59+
return transformer.Transformer(hparams, mode, p_hparams), features
6160

6261
def testTransformer(self):
6362
model, features = self.getModel(transformer.transformer_small())
@@ -99,8 +98,7 @@ def testGreedyVsFast(self):
9998
mode=tf.estimator.ModeKeys.PREDICT)
10099

101100
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
102-
greedy_result, _, _ = model._slow_greedy_infer(
103-
features, decode_length, last_position_only=True)
101+
greedy_result, _, _ = model._slow_greedy_infer(features, decode_length)
104102
greedy_result = tf.squeeze(greedy_result, axis=[2, 3])
105103

106104
fast_result, _, _ = model._greedy_infer(features, decode_length)
@@ -139,15 +137,13 @@ def testBeamVsFast(self):
139137
decode_length,
140138
beam_size=4,
141139
top_beams=1,
142-
last_position_only=True,
143140
alpha=1.0)["outputs"]
144141

145142
fast_result = model._beam_decode(
146143
features,
147144
decode_length,
148145
beam_size=4,
149146
top_beams=1,
150-
last_position_only=True,
151147
alpha=1.0)["outputs"]
152148

153149
with self.test_session():

tensor2tensor/utils/decoding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
def decode_hparams(overrides=""):
4343
"""Hyperparameters for decoding."""
4444
hp = tf.contrib.training.HParams(
45-
use_last_position_only=False,
4645
save_images=False,
4746
problem_idx=0,
4847
extra_length=50,

tensor2tensor/utils/modality.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,22 @@ def top_dimensionality(self):
7171
def _body_input_depth(self):
7272
return self._model_hparams.hidden_size
7373

74+
@property
75+
def top_is_pointwise(self):
76+
"""Whether the top mapping of the modality is pointwise.
77+
78+
An example of a pointwise top mapping is a linear layer followed by
79+
a softmax. Given a tensor [batch, length, height, depth] it operates
80+
only on the last axis, on every point in [batch, length, height] fully
81+
independently. In contrast, a classifier that first averages over length
82+
and height is not pointwise, as it depends on the whole field. It is useful
83+
to know if a top is pointwise to speed up decoding in certain models.
84+
85+
Returns:
86+
A Boolean, True if the modality is pointwise, False otherwise (default).
87+
"""
88+
return False
89+
7490
def bottom(self, x):
7591
"""Transform one shard of input.
7692

tensor2tensor/utils/model_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def nth_model(n):
115115
features,
116116
beam_size=decode_hp.beam_size,
117117
top_beams=(decode_hp.beam_size if decode_hp.return_beams else 1),
118-
last_position_only=decode_hp.use_last_position_only,
119118
alpha=decode_hp.alpha,
120119
decode_length=decode_hp.extra_length)
121120
# In distributed mode, we build graph for problem=0 and problem=worker_id.

0 commit comments

Comments (0)