diff --git a/AnnotatedTransformer.ipynb b/AnnotatedTransformer.ipynb index 0f7da7d..f517d14 100644 --- a/AnnotatedTransformer.ipynb +++ b/AnnotatedTransformer.ipynb @@ -2497,7 +2497,7 @@ "def greedy_decode(model, src, src_mask, max_len, start_symbol):\n", " memory = model.encode(src, src_mask)\n", " ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)\n", - " for i in range(max_len - 1):\n", + " for i in range(max_len):\n", " out = model.decode(\n", " memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)\n", " )\n", @@ -2507,7 +2507,9 @@ " ys = torch.cat(\n", " [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1\n", " )\n", - " return ys" + " \n", + " # Return the target sequence without the start symbol\n", + " return ys[:, 1:]" ] }, { diff --git a/the_annotated_transformer.py b/the_annotated_transformer.py index 4aa1d46..f31cd3f 100644 --- a/the_annotated_transformer.py +++ b/the_annotated_transformer.py @@ -1313,7 +1313,7 @@ def __call__(self, x, y, norm): def greedy_decode(model, src, src_mask, max_len, start_symbol): memory = model.encode(src, src_mask) ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data) - for i in range(max_len - 1): + for i in range(max_len): out = model.decode( memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data) ) @@ -1323,7 +1323,9 @@ def greedy_decode(model, src, src_mask, max_len, start_symbol): ys = torch.cat( [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1 ) - return ys + + # Return the target sequence without the start symbol + return ys[:, 1:] # %% id="qgIZ2yEtdYwe" tags=[]