left padding for LM inference
l-k-11235 committed Nov 22, 2023
1 parent 3d4c8de commit 802ec8c
Showing 6 changed files with 85 additions and 62 deletions.
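This commit switches batches for LM (decoder-only) inference from right padding to left padding. With right padding, shorter prompts in a batch end at different positions, so pad tokens would sit between a prompt and the tokens generated after it; with left padding, every prompt ends at the last column and all rows can continue generating from the same position. A minimal illustration (hypothetical token ids, 0 as the pad index — not part of the diff):

    # Right padding: row 2's next token should follow position 1, not position 2.
    right_padded = [[11, 12, 13],
                    [21, 22,  0]]
    # Left padding: every row continues from the final column.
    left_padded  = [[11, 12, 13],
                    [ 0, 21, 22]]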
3 changes: 1 addition & 2 deletions onmt/decoders/transformer.py
@@ -717,8 +717,7 @@ def _forward(
         dec_mask = None
 
         if layer_in.size(1) > 1:
-            # masking is necessary when sequence length is greater than one
-            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
+            dec_mask = tgt_pad_mask
             dec_mask = dec_mask.unsqueeze(1)
             dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
             # mask now are (batch x 1 x tlen x tlen)
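The hunk above drops the `_compute_dec_mask` call when masking the full prompt and instead expands the padding mask directly. A minimal sketch of the resulting shapes, assuming `tgt_pad_mask` is boolean with shape (batch, 1, tlen) and True at pad positions (an assumption based on the surrounding code, not stated in the diff):

    import torch

    batch, tlen = 2, 4
    tgt_pad_mask = torch.zeros(batch, 1, tlen, dtype=torch.bool)
    dec_mask = tgt_pad_mask.unsqueeze(1)                      # (batch, 1, 1, tlen)
    dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)  # (batch, 1, tlen, tlen)
    assert dec_mask.shape == (batch, 1, tlen, tlen)

The causal part of the old mask is presumably handled inside the attention module instead (see the `causal`/`is_causal` flags in the multi_headed_attn.py hunk below).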
2 changes: 2 additions & 0 deletions onmt/inference_engine.py
@@ -30,6 +30,7 @@ def infer_file(self):
             self.transforms_cls,
             self.vocabs,
             task=CorpusTask.INFER,
+            model_task=self.translator.model_task,
             device_id=self.device_id,
         )
         scores, preds = self._translate(infer_iter)
@@ -45,6 +46,7 @@ def infer_list(self, src):
             self.transforms_cls,
             self.vocabs,
             task=CorpusTask.INFER,
+            model_task=self.translator.model_task,
             src=src,
             device_id=self.device_id,
         )
14 changes: 11 additions & 3 deletions onmt/inputters/dynamic_iterator.py
@@ -1,7 +1,7 @@
 """Module that contain iterator used for dynamic data."""
 import torch
 from itertools import cycle
-from onmt.constants import CorpusTask
+from onmt.constants import CorpusTask, ModelTask
 from onmt.inputters.text_corpus import get_corpora, build_corpora_iters
 from onmt.inputters.text_utils import (
     text_sort_key,
@@ -126,6 +126,7 @@ def __init__(
         transforms,
         vocabs,
         task,
+        model_task,
         batch_type,
         batch_size,
         batch_size_multiple,
@@ -164,10 +165,14 @@ def __init__(
         self.skip_empty_level = skip_empty_level
         self.random_shuffler = RandomShuffler()
         self.bucket_idx = 0
+        if task != CorpusTask.TRAIN and model_task == ModelTask.LANGUAGE_MODEL:
+            self.left_pad = True
+        else:
+            self.left_pad = False
 
     @classmethod
     def from_opt(
-        cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
+        cls, corpora, transforms, vocabs, opt, task, model_task, copy, device, stride=1, offset=0
     ):
         """Initilize `DynamicDatasetIter` with options parsed from `opt`."""
         corpora_info = {}
@@ -199,6 +204,7 @@ def from_opt(
             transforms,
             vocabs,
             task,
+            model_task,
             opt.batch_type,
             batch_size,
             batch_size_multiple,
@@ -354,7 +360,7 @@ def __iter__(self):
             # within the batch
             if self.task == CorpusTask.TRAIN:
                 minibatch.sort(key=lambda x: self.sort_key(x[0]), reverse=True)
-            tensor_batch = tensorify(self.vocabs, minibatch, self.device)
+            tensor_batch = tensorify(self.vocabs, minibatch, self.device, self.left_pad)
             yield (tensor_batch, bucket_idx)


@@ -382,6 +388,7 @@ def build_dynamic_dataset_iter(
     vocabs,
     copy=False,
     task=CorpusTask.TRAIN,
+    model_task=ModelTask.SEQ2SEQ,
     stride=1,
     offset=0,
     src=None,
@@ -420,6 +427,7 @@ def build_dynamic_dataset_iter(
         vocabs,
         opt,
         task,
+        model_task=model_task,
         copy=copy,
         stride=stride,
         offset=offset,
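In short, left padding is turned on only when iterating for inference over a language model; training, and inference with seq2seq models, keep right padding. A one-line restatement of the new switch, as a hypothetical helper mirroring the `__init__` logic above and assuming the enums imported in this file:

    from onmt.constants import CorpusTask, ModelTask

    def use_left_pad(task, model_task):
        # Left padding applies only to decoder-only LM inference.
        return task != CorpusTask.TRAIN and model_task == ModelTask.LANGUAGE_MODEL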
41 changes: 29 additions & 12 deletions onmt/inputters/text_utils.py
@@ -168,7 +168,7 @@ def parse_align_idx(align_pharaoh):
     return flatten_align_idx
 
 
-def tensorify(vocabs, minibatch, device):
+def tensorify(vocabs, minibatch, device, left_pad):
     """
     This function transforms a batch of example in tensors
     Each example looks like
@@ -193,21 +193,35 @@ def tensorify(vocabs, minibatch, device):
     }
     """
     tensor_batch = {}
-    tbatchsrc = [
-        torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
-        for ex, indice in minibatch
-    ]
+    if left_pad:
+        tbatchsrc = [
+            torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device).flip(dims=[0])
+            for ex, indice in minibatch
+        ]
+    else:
+        tbatchsrc = [
+            torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
+            for ex, indice in minibatch
+        ]
     padidx = vocabs["src"][DefaultTokens.PAD]
     tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx)
     if "feats" in minibatch[0][0]["src"]:
         tbatchfs = [tbatchsrc]
         for feat_id in range(len(minibatch[0][0]["src"]["feats"])):
-            tbatchfeat = [
-                torch.tensor(
-                    ex["src"]["feats"][feat_id], dtype=torch.long, device=device
-                )
-                for ex, indice in minibatch
-            ]
+            if left_pad:
+                tbatchfeat = [
+                    torch.tensor(
+                        ex["src"]["feats"][feat_id], dtype=torch.long, device=device
+                    ).flip(dims=[0])
+                    for ex, indice in minibatch
+                ]
+            else:
+                tbatchfeat = [
+                    torch.tensor(
+                        ex["src"]["feats"][feat_id], dtype=torch.long, device=device
+                    )
+                    for ex, indice in minibatch
+                ]
             padidx = vocabs["src_feats"][feat_id][DefaultTokens.PAD]
             tbatchfeat = pad_sequence(
                 tbatchfeat, batch_first=True, padding_value=padidx
@@ -218,7 +232,10 @@ def tensorify(vocabs, minibatch, device):
         # Need to add features in last dimensions
         tbatchsrc = tbatchsrc[:, :, None]
 
-    tensor_batch["src"] = tbatchsrc
+    if left_pad:
+        tensor_batch["src"] = tbatchsrc.flip(dims=[1])
+    else:
+        tensor_batch["src"] = tbatchsrc
 
     tensor_batch["srclen"] = torch.tensor(
         [len(ex["src"]["src_ids"]) for ex, indice in minibatch],
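`pad_sequence` only pads on the right, so `tensorify` builds left-padded batches with a flip–pad–flip trick: flip each sequence, right-pad as usual, then flip the padded batch back along the time dimension. A self-contained sketch of the trick (hypothetical toy data, 0 as the pad index):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    PAD = 0
    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    reversed_seqs = [s.flip(dims=[0]) for s in seqs]   # [7, 6, 5], [9, 8]
    batch = pad_sequence(reversed_seqs, batch_first=True, padding_value=PAD)
    left_padded = batch.flip(dims=[1])                 # flip the time axis back
    print(left_padded)
    # tensor([[5, 6, 7],
    #         [0, 8, 9]])

Note that the diff flips `tbatchsrc` back with `flip(dims=[1])` only at the very end, after the optional feature tensors (which were flipped the same way) have been stacked, so dimension 1 is still the time axis.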
67 changes: 34 additions & 33 deletions onmt/modules/multi_headed_attn.py
@@ -491,41 +491,42 @@ def forward(
             self.flash2
             and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
         )
-        if False:
-            if (
-                self.max_relative_positions in [-1, 0]
-                and not return_attn
-                and query.device != torch.device("cpu")
-            ):
-                causal = self.is_decoder and self.attn_type == "self" and mask is not None
-                if self.is_decoder and self.attn_type == "self" and flash2:
-                    if causal:
-                        window_size = (
-                            (-1, -1) if sliding_window == 0 else (sliding_window, 0)
-                        )
-                    else:
-                        window_size = (-1, -1)
-                    attn_output = self.flash_attn_func(
-                        query.transpose(1, 2),
-                        key.transpose(1, 2),
-                        value.transpose(1, 2),
-                        dropout_p=self.dropout_p,
-                        causal=causal,
-                        window_size=window_size,
-                    ).transpose(1, 2)
-                else:
-                    with torch.backends.cuda.sdp_kernel(
-                        enable_flash=False, enable_math=True, enable_mem_efficient=True
-                    ):
-                        attn_output = scaled_dot_product_attention(
-                            query,
-                            key,
-                            value,
-                            ~mask if mask is not None else None,
-                            self.dropout_p,
-                            is_causal=causal,
-                        )
-                attn = None
+
+        if (
+            self.max_relative_positions in [-1, 0]
+            and not return_attn
+            and query.device != torch.device("cpu")
+        ):
+            causal = self.is_decoder and self.attn_type == "self" and mask is not None
+            if self.is_decoder and self.attn_type == "self" and flash2:
+                if causal:
+                    window_size = (
+                        (-1, -1) if sliding_window == 0 else (sliding_window, 0)
+                    )
+                else:
+                    window_size = (-1, -1)
+                attn_output = self.flash_attn_func(
+                    query.transpose(1, 2),
+                    key.transpose(1, 2),
+                    value.transpose(1, 2),
+                    dropout_p=self.dropout_p,
+                    causal=causal,
+                    window_size=window_size,
+                ).transpose(1, 2)
+            else:
+                with torch.backends.cuda.sdp_kernel(
+                    enable_flash=False, enable_math=True, enable_mem_efficient=True
+                ):
+                    attn_output = scaled_dot_product_attention(
+                        query,
+                        key,
+                        value,
+                        ~mask if mask is not None else None,
+                        self.dropout_p,
+                        is_causal=causal,
+                    )
+            attn = None
 
         else:
             query /= sqrt(self.dim_per_head)
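The hunk above deletes a dead `if False:` guard and dedents its body, re-enabling the fused attention paths: flash-attn 2 when available, otherwise PyTorch's `scaled_dot_product_attention`. A minimal sketch of the SDPA fallback, assuming PyTorch >= 2.0; a boolean `attn_mask` uses True for positions that may be attended, which is why the code above passes `~mask` (in OpenNMT the mask marks padding):

    import torch
    import torch.nn.functional as F

    b, h, t, d = 2, 4, 5, 8
    q = k = v = torch.randn(b, h, t, d)
    pad_mask = torch.zeros(b, 1, t, t, dtype=torch.bool)  # True would mark padding
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=~pad_mask, dropout_p=0.0)
    print(out.shape)  # torch.Size([2, 4, 5, 8])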
20 changes: 8 additions & 12 deletions onmt/translate/translator.py
@@ -103,6 +103,7 @@ def __init__(
         self,
         model,
         vocabs,
+        model_task,
         gpu=-1,
         n_best=1,
         min_length=0,
@@ -136,6 +137,7 @@ def __init__(
     ):
         self.model = model
         self.vocabs = vocabs
+        self.model_task = model_task
         self._tgt_vocab = vocabs["tgt"]
         self._tgt_eos_idx = vocabs["tgt"].lookup_token(DefaultTokens.EOS)
         self._tgt_pad_idx = vocabs["tgt"].lookup_token(DefaultTokens.PAD)
@@ -241,6 +243,7 @@ def from_opt(
         return cls(
             model,
             vocabs,
+            model_task=model_opt.model_task,
             gpu=opt.gpu,
             n_best=opt.n_best,
             min_length=opt.min_length,
@@ -988,16 +991,6 @@ def _align_forward(self, batch, predictions):
 
     def translate_batch(self, batch, attn_debug):
         """Translate a batch of sentences."""
-        batch_size = len(batch["srclen"])
-        if batch_size != 1:
-            warning_msg = (
-                "GeneratorLM does not support batch_size != 1"
-                " nicely. You can remove this limitation here."
-                " With batch_size > 1 the end of each input is"
-                " repeated until the input is finished. Then"
-                " generation will start."
-            )
-            self._log(warning_msg)
         with torch.no_grad():
             if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                 decode_strategy = GreedySearchLM(
@@ -1061,7 +1054,7 @@ def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
         log_probs = log_probs[:, -1, :]
         return log_probs
 
-    def _translate_batch_with_strategy(self, batch, decode_strategy):
+    def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
         """Translate a batch of sentences step by step using cache.
 
         Args:
@@ -1081,7 +1074,10 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
         src = batch["src"]
         src_len = batch["srclen"]
 
-        src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len)
+        if left_pad:
+            target_prefix = None
+        else:
+            src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len)
 
         # (2) init decoder
         self.model.decoder.init_state(src, None, None)
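With left-padded batches, the decoder no longer needs the `split_src_to_prevent_padding` workaround, which truncated every prompt to the shortest one in the batch and replayed the remainder as a forced target prefix — the source of the old batch_size != 1 warning removed above. A sketch of the two paths, as a hypothetical condensation of the branch added in `_translate_batch_with_strategy`:

    def prepare_lm_src(translator, src, src_len, left_pad=True):
        if left_pad:
            return src, src_len, None  # keep the full left-padded batch; no prefix
        # Legacy right-padding path: truncate at the shortest prompt and
        # replay the rest of each sequence as a forced target prefix.
        return translator.split_src_to_prevent_padding(src, src_len)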
