Skip to content

Commit 0c3ccc6

Browse files
committed
Clamping beam_idxs, still fails trying to cat past_states to hidden_states
1 parent 0899ad8 commit 0c3ccc6

File tree

1 file changed: +5 lines added, -4 lines removed

torchtext/prototype/generate.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -298,25 +298,26 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
298298

299299
# We could store this in model_kwargs
300300
num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
301-
301+
302302
num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
303303
if num_finished_hyps_in_step > 0:
304304
beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
305-
305+
306+
beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
307+
306308
reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
307309

308310
if num_finished_hyps_in_step > 0:
309311
sliced_cache = ()
310312
for states in reordered_cached:
311313
sliced_state = ()
312314
for state in states:
313-
sliced_state = sliced_state + (state[:len(prev_step_hyp_idxs)],)
315+
sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
314316
sliced_cache = sliced_cache + (sliced_state,)
315317
reordered_cached = sliced_cache
316318

317319
model_inputs["past_key_values"] = reordered_cached
318320

319-
320321
# Forward pass
321322
outputs = self.model(**model_inputs)
322323

0 commit comments

Comments (0)