@@ -143,52 +143,46 @@ def model_fn_body(self, features):
143
143
encoder_decoder_attention_bias ,
144
144
decoder_self_attention_bias , hparams )
145
145
146
- def _greedy_infer (self , features , decode_length , last_position_only = True ):
146
+ def _greedy_infer (self , features , decode_length ):
147
147
"""Fast version of greedy decoding.
148
148
149
149
Args:
150
150
features: a map of string to `Tensor`
151
151
decode_length: an integer. How many additional timesteps to decode.
152
- last_position_only: MUST be true for fast decoding!
153
152
154
153
Returns:
155
154
samples: [batch_size, input_length + decode_length]
156
155
logits: Not returned
157
156
losses: Not returned
158
157
159
158
Raises:
160
- ValueError: If last_position_only if False
161
159
NotImplementedError: If there are multiple data shards.
162
160
"""
163
- decoded_ids , _ = self ._fast_decode (
164
- features , decode_length , last_position_only )
161
+ decoded_ids , _ = self ._fast_decode (features , decode_length )
165
162
return decoded_ids , None , None
166
163
167
164
def _beam_decode (self , features , decode_length , beam_size , top_beams ,
168
- last_position_only , alpha ):
165
+ alpha ):
169
166
"""Beam search decoding.
170
167
171
168
Args:
172
169
features: a map of string to `Tensor`
173
170
decode_length: an integer. How many additional timesteps to decode.
174
171
beam_size: number of beams.
175
172
top_beams: an integer. How many of the beams to return.
176
- last_position_only: MUST be true for fast decoding!
177
173
alpha: Float that controls the length penalty. The larger the alpha, the stronger
178
174
the preference for longer translations.
179
175
180
176
Returns:
181
177
samples: an integer `Tensor`. Top samples from the beam search
182
178
"""
183
179
decoded_ids , scores = self ._fast_decode (
184
- features , decode_length , last_position_only , beam_size , top_beams ,
185
- alpha )
180
+ features , decode_length , beam_size , top_beams , alpha )
186
181
return {"outputs" : decoded_ids , "scores" : scores }
187
182
188
183
def _fast_decode (self ,
189
184
features ,
190
185
decode_length ,
191
- last_position_only = True ,
192
186
beam_size = 1 ,
193
187
top_beams = 1 ,
194
188
alpha = 1.0 ):
@@ -200,7 +194,6 @@ def _fast_decode(self,
200
194
Args:
201
195
features: a map of string to model features.
202
196
decode_length: an integer. How many additional timesteps to decode.
203
- last_position_only: MUST be true for fast decoding!
204
197
beam_size: number of beams.
205
198
top_beams: an integer. How many of the beams to return.
206
199
alpha: Float that controls the length penalty. The larger the alpha, the stronger
@@ -210,11 +203,8 @@ def _fast_decode(self,
210
203
samples: an integer `Tensor`. Top samples from the beam search
211
204
212
205
Raises:
213
- ValueError: If last_position_only if False
214
206
NotImplementedError: If there are multiple data shards.
215
207
"""
216
- if not last_position_only :
217
- raise ValueError ("Fast decoding only deals with the last positions!" )
218
208
if self ._num_datashards != 1 :
219
209
raise NotImplementedError ("Fast decoding only supports a single shard." )
220
210
dp = self ._data_parallelism
0 commit comments