Use native CrossEntropyLoss including label_smoothing + more optimisations (OpenNMT#2270)

* use native crossentropy
* doc
* fix transform bug
vince62s authored Dec 9, 2022
1 parent 9698acd commit 9d617b8
Showing 18 changed files with 55 additions and 95 deletions.
1 change: 0 additions & 1 deletion .github/workflows/push.yml
@@ -122,7 +122,6 @@ jobs:
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-max_generator_batches 0 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
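Throughout the configs and test scripts below, `-max_generator_batches` is dropped: it used to shard the decoder output so the generator's log-softmax and the loss were computed piece by piece to cap memory, and with the loss now taking raw logits in one `nn.CrossEntropyLoss` call the sharding is no longer needed. A rough, self-contained illustration of the two styles (made-up sizes, not OpenNMT code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dec_out = torch.randn(2048, 512)           # flattened decoder states (tokens, hidden)
target = torch.randint(0, 8000, (2048,))   # gold token ids
generator = nn.Linear(512, 8000)
criterion = nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.1)

# Old style: shard the output (size controlled by max_generator_batches) and
# accumulate the loss shard by shard to bound peak memory.
shard_size = 512
loss_sharded = sum(criterion(generator(o), t)
                   for o, t in zip(dec_out.split(shard_size),
                                   target.split(shard_size)))

# New style: a single pass over the raw logits.
loss = criterion(generator(dec_out), target)
assert torch.allclose(loss, loss_sharded, rtol=1e-4)  # same value, up to float error
```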
5 changes: 5 additions & 0 deletions README.md
@@ -16,6 +16,11 @@ Unless there is a bug, please use the [forum](https://forum.opennmt.net) or [Git

----

There is a new step-by-step, fully explained tutorial (thanks to Yasmin Moslem) here:
Please read and/or follow it before raising newbie issues: [Tutorial](https://github.com/ymoslem/OpenNMT-Tutorial)

----

# OpenNMT-py 3.0

**We're happy to announce the release v3.0 of OpenNMT-py.**
2 changes: 0 additions & 2 deletions config/config-transformer-base-1GPU.yml
@@ -30,8 +30,6 @@ normalization: tokens
dropout: 0.1
label_smoothing: 0.1

max_generator_batches: 2

param_init: 0.0
param_init_glorot: 'true'
position_encoding: 'true'
2 changes: 0 additions & 2 deletions config/config-transformer-base-4GPU.yml
@@ -30,8 +30,6 @@ normalization: tokens
dropout: 0.1
label_smoothing: 0.1

max_generator_batches: 2

param_init: 0.0
param_init_glorot: 'true'
position_encoding: 'true'
3 changes: 1 addition & 2 deletions docs/source/FAQ.md
@@ -1,7 +1,7 @@
# FAQ
All the example YAML configurations are partial. To get an overview of what this YAML configuration is you can start by reading the [Quickstart](quickstart) section.

Also you can have a look at this: [Tutorial]https://github.com/ymoslem/OpenNMT-Tutorial
Also you can have a look at this: [Tutorial](https://github.com/ymoslem/OpenNMT-Tutorial)

## How do I use my v2 models in v3 ?

@@ -58,7 +58,6 @@ num_workers: 2
batch_type: "tokens"
batch_size: 4096
valid_batch_size: 2048
max_generator_batches: 2
accum_count: [4]
accum_steps: [0]

1 change: 0 additions & 1 deletion docs/source/examples/Translation.md
@@ -96,7 +96,6 @@ batch_type: "tokens"
batch_size: 4096
valid_batch_size: 2048
batch_size_multiple: 8
max_generator_batches: 0
accum_count: [3]
accum_steps: [0]
1 change: 1 addition & 0 deletions docs/source/main.md
@@ -3,6 +3,7 @@

This portal provides a detailed documentation of the OpenNMT-py toolkit. It describes how to use the PyTorch project and how it works.

If you need a step-by-step overview, please read here: [Tutorial](https://github.com/ymoslem/OpenNMT-Tutorial)


## Installation
1 change: 0 additions & 1 deletion examples/onmt.train.fp16.transformer.yaml
@@ -71,7 +71,6 @@ batch_type: "tokens"
batch_size: 4096
valid_batch_size: 8
batch_size_multiple: 8
max_generator_batches: 0
accum_count: [3]
accum_steps: [0]

1 change: 0 additions & 1 deletion examples/wmt14_en_de.yaml
@@ -73,7 +73,6 @@ batch_type: "tokens"
batch_size: 4096
valid_batch_size: 2048
batch_size_multiple: 8
max_generator_batches: 0
accum_count: [3]
accum_steps: [0]

34 changes: 16 additions & 18 deletions onmt/model_builder.py
@@ -12,7 +12,6 @@
from onmt.decoders import str2dec
from onmt.inputters.inputter import dict_to_vocabs
from onmt.modules import Embeddings, CopyGenerator
from onmt.modules.util_class import Cast
from onmt.utils.misc import use_gpu
from onmt.utils.logging import logger
from onmt.utils.parse import ArgumentParser
@@ -192,19 +191,19 @@ def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint):
emb_name
][old_i]
if side == 'tgt':
generator.state_dict()['0.weight'][i] = checkpoint[
generator.state_dict()['weight'][i] = checkpoint[
'generator'
]['0.weight'][old_i]
generator.state_dict()['0.bias'][i] = checkpoint[
]['weight'][old_i]
generator.state_dict()['bias'][i] = checkpoint[
'generator'
]['0.bias'][old_i]
]['bias'][old_i]
else:
# Just for debugging purposes
new_tokens.append(tok)
logger.info("%s: %d new tokens" % (side, len(new_tokens)))
# Remove old vocabulary associated embeddings
del checkpoint['model'][emb_name]
del checkpoint['generator']['0.weight'], checkpoint['generator']['0.bias']
del checkpoint['generator']['weight'], checkpoint['generator']['bias']


def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):
@@ -243,18 +242,10 @@ def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):

# Build Generator.
if not model_opt.copy_attn:
if model_opt.generator_function == "sparsemax":
gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
else:
gen_func = nn.LogSoftmax(dim=-1)
generator = nn.Sequential(
nn.Linear(model_opt.dec_hid_size,
len(vocabs['tgt'])),
Cast(torch.float32),
gen_func
)
generator = nn.Linear(model_opt.dec_hid_size,
len(vocabs['tgt']))
if model_opt.share_decoder_embeddings:
generator[0].weight = model.decoder.embeddings.word_lut.weight
generator.weight = model.decoder.embeddings.word_lut.weight
else:
vocab_size = len(vocabs['tgt'])
pad_idx = vocabs['tgt'][DefaultTokens.PAD]
@@ -295,8 +286,15 @@ def fix_key(s):

checkpoint['model'] = {fix_key(k): v
for k, v in checkpoint['model'].items()}
# end of patch for backward compatibility

if '0.weight' in checkpoint['generator']:
checkpoint['generator']['weight'] =\
checkpoint['generator'].pop('0.weight')
if '0.bias' in checkpoint['generator']:
checkpoint['generator']['bias'] =\
checkpoint['generator'].pop('0.bias')

# end of patch for backward compatibility
if model_opt.update_vocab:
# Update model embeddings with those from the checkpoint
# after initialization
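The generator is now a bare `nn.Linear` producing raw logits (the `Cast` + `LogSoftmax`/`LogSparsemax` wrapper is gone), so old checkpoints that stored it as an `nn.Sequential` need their `0.weight`/`0.bias` keys renamed, which is what the backward-compatibility patch above does. A minimal sketch of that migration (sizes and helper name are illustrative, not the OpenNMT API):

```python
import torch
import torch.nn as nn

dec_hid_size, vocab_size = 512, 8000
generator = nn.Linear(dec_hid_size, vocab_size)   # v3 generator: raw logits only
# with share_decoder_embeddings, generator.weight would be tied to the target
# embedding table instead of being a free parameter

def upgrade_generator_state(gen_state):
    """Rename Sequential-era keys so an old checkpoint loads into nn.Linear."""
    if '0.weight' in gen_state:
        gen_state['weight'] = gen_state.pop('0.weight')
    if '0.bias' in gen_state:
        gen_state['bias'] = gen_state.pop('0.bias')
    return gen_state

# simulate a checkpoint saved when the generator was nn.Sequential(Linear, ...)
old_generator = nn.Sequential(nn.Linear(dec_hid_size, vocab_size))
old_state = dict(old_generator.state_dict())      # keys: '0.weight', '0.bias'
generator.load_state_dict(upgrade_generator_state(old_state))
```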
1 change: 0 additions & 1 deletion onmt/tests/pull_request_chk.sh
@@ -147,7 +147,6 @@ ${PYTHON} onmt/bin/train.py \
-tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-max_generator_batches 0 \
-num_workers 0 -bucket_size 1024 \
-encoder_type transformer -decoder_type transformer \
-layers 4 -word_vec_size 16 -hidden_size 16 -heads 2 -transformer_ff 64 \
4 changes: 2 additions & 2 deletions onmt/tests/rebuild_test_models.sh
@@ -80,7 +80,7 @@ $my_python train.py \
-encoder_type transformer -decoder_type transformer \
-share_embedding -share_vocab \
-train_steps 1000 -world_size 1 -gpu_ranks 0 \
-max_generator_batches 2 -dropout 0.1 \
-dropout 0.1 \
-normalization tokens \
-max_grad_norm 0 -optim adam -decay_method noam \
-learning_rate 2 -label_smoothing 0.1 \
@@ -110,7 +110,7 @@ $my_python build_vocab.py \
$my_python train.py -config data/lm_data.yaml -save_model /tmp/tmp \
-accum_count 2 -dec_layers 2 -hidden_size 64 -word_vec_size 64 -batch_size 256 \
-encoder_type transformer_lm -decoder_type transformer_lm -share_embedding \
-train_steps 2000 -max_generator_batches 4 -dropout 0.1 -normalization tokens \
-train_steps 2000 -dropout 0.1 -normalization tokens \
-share_vocab -transformer_ff 256 -max_grad_norm 0 -optim adam -decay_method noam \
-learning_rate 2 -label_smoothing 0.1 -model_task lm -world_size 1 -gpu_ranks 0 \
-attention_dropout 0.1 -heads 2 -position_encoding -param_init 0 -warmup_steps 100 \
1 change: 0 additions & 1 deletion onmt/tests/test_models.sh
@@ -226,7 +226,6 @@ transformer(){
-decoder_type transformer \
-train_steps 10000 \
-gpuid $GPUID \
-max_generator_batches 4 \
-dropout 0.1 \
-normalization tokens \
-max_grad_norm 0 \
5 changes: 4 additions & 1 deletion onmt/trainer.py
@@ -253,6 +253,8 @@ def train(self,
total_stats = onmt.utils.Statistics()
report_stats = onmt.utils.Statistics()
self._start_report_manager(start_time=total_stats.start_time)
# Let's clean the GPUs before training loop
torch.cuda.empty_cache()

for i, batches in enumerate(
self._accum_batches(train_iter)):
@@ -277,7 +279,7 @@ valid_stats = self.validate(
valid_stats = self.validate(
valid_iter, moving_average=self.moving_average)

if step % valid_steps == 0:
if step % valid_steps == 0 and self.gpu_rank <= 0:
self._report_step(self.optim.learning_rate(),
step, valid_stats=valid_stats,
train_stats=total_stats)
@@ -488,6 +490,7 @@ def _gradient_accumulation(self, true_batches, total_stats,
if "CUDA out of memory" in trace_content:
logger.info("Step %d, cuda OOM - batch removed",
self.optim.training_step)
torch.cuda.empty_cache()
else:
traceback.print_exc()
raise exc
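Two things change in the trainer: the CUDA cache is emptied before the training loop and after a caught OOM, so a single oversized batch is dropped instead of killing the run, and validation statistics are reported only from the first process (`self.gpu_rank <= 0`), avoiding duplicate logs in multi-GPU training. A rough sketch of the OOM-tolerant pattern, with stand-in model/criterion/optimizer rather than the Trainer API:

```python
import traceback
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

if torch.cuda.is_available():
    torch.cuda.empty_cache()          # start the loop from a clean allocator state

for step, batch in enumerate(torch.randn(10, 8, 16)):
    try:
        loss = criterion(model(batch), torch.zeros_like(batch))
        loss.backward()
        optimizer.step()
    except RuntimeError as exc:
        if "CUDA out of memory" in str(exc):
            print(f"Step {step}, cuda OOM - batch removed")
            torch.cuda.empty_cache()  # release cached blocks so later batches can fit
        else:
            traceback.print_exc()
            raise exc
    finally:
        optimizer.zero_grad(set_to_none=True)
```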
5 changes: 4 additions & 1 deletion onmt/transforms/transform.py
@@ -65,9 +65,12 @@ def batch_apply(self, batch, is_train=False, **kwargs):
batch (list): a list of examples;
is_train (bool): Indicate if src/tgt is training data.
"""
transformed_batch = []
for (example, _, cid) in batch:
example = self.apply(example, is_train=is_train, **kwargs)
return batch
if example is not None:
transformed_batch.append((example, self, cid))
return transformed_batch

def apply_reverse(self, translated):
return translated
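This is the transform bug fix mentioned in the commit message: `batch_apply` previously threw away the per-example result of `apply()` and returned the original batch, so filtering transforms could never drop anything. It now collects the transformed examples and skips those for which `apply()` returns `None`. A toy illustration with a hypothetical length filter (not one of the real transform classes):

```python
class FilterTooLong:
    """Toy transform: drop examples whose source side is longer than max_len."""

    def __init__(self, max_len=6):
        self.max_len = max_len

    def apply(self, example, is_train=False, **kwargs):
        return None if len(example['src']) > self.max_len else example

    def batch_apply(self, batch, is_train=False, **kwargs):
        # mirrors the fixed Transform.batch_apply: keep only non-None results
        transformed_batch = []
        for (example, _, cid) in batch:
            example = self.apply(example, is_train=is_train, **kwargs)
            if example is not None:
                transformed_batch.append((example, self, cid))
        return transformed_batch


batch = [({'src': 's h o r t'.split()}, None, 'corpus_a'),
         ({'src': 'a very very very long sentence indeed'.split()}, None, 'corpus_a')]
print(len(FilterTooLong().batch_apply(batch, is_train=True)))  # 1: long example removed
```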
5 changes: 3 additions & 2 deletions onmt/translate/translator.py
@@ -7,7 +7,7 @@
from itertools import count, zip_longest

import torch

import torch.nn.functional as F
from onmt.constants import DefaultTokens
import onmt.model_builder
import onmt.decoders.ensemble
@@ -547,7 +547,8 @@ def _decode_and_generate(
else:
attn = None

log_probs = self.model.generator(dec_out.squeeze(1))
scores = self.model.generator(dec_out.squeeze(1))
log_probs = F.log_softmax(scores.to(torch.float32), dim=-1)
# returns [(batch_size x beam_size) , vocab ] when 1 step
# or [batch_size, tgt_len, vocab ] when full sentence
else:
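Since the generator now returns raw logits, the translator normalizes them itself: the scores are upcast to float32 and passed through `F.log_softmax` before beam search consumes them. A minimal sketch of that step, with illustrative shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

generator = nn.Linear(512, 8000).eval()
dec_out = torch.randn(8, 1, 512)               # (batch*beam, 1 step, hidden)

with torch.no_grad():
    scores = generator(dec_out.squeeze(1))     # raw logits
    # upcasting matters when the model runs in fp16; harmless in fp32
    log_probs = F.log_softmax(scores.to(torch.float32), dim=-1)
print(log_probs.shape)                          # torch.Size([8, 8000])
```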
72 changes: 16 additions & 56 deletions onmt/utils/loss.py
@@ -76,24 +76,22 @@ def from_opts(cls, opt, model, vocab, train=True):
"order to use --lambda_coverage != 0"

tgt_shift_idx = 1 if opt.model_task == ModelTask.SEQ2SEQ else 0

label_smoothing = opt.label_smoothing if train else 0
if opt.copy_attn:
criterion = onmt.modules.CopyGeneratorLoss(
len(vocab), opt.copy_attn_force,
unk_index=unk_idx, ignore_index=padding_idx
)
else:
if opt.label_smoothing > 0 and train:
criterion = LabelSmoothingLoss(
opt.label_smoothing, len(vocab),
ignore_index=padding_idx
)
elif isinstance(model.generator[-1], LogSparsemax):
if opt.generator_function == 'sparsemax':
criterion = SparsemaxLoss(ignore_index=padding_idx,
reduction='sum')
else:
criterion = nn.NLLLoss(ignore_index=padding_idx,
reduction='sum')
criterion = nn.CrossEntropyLoss(
ignore_index=padding_idx,
reduction='sum',
label_smoothing=label_smoothing
)

lm_prior_lambda = opt.lm_prior_lambda
lm_prior_tau = opt.lm_prior_tau
@@ -120,15 +118,7 @@ def from_opts(cls, opt, model, vocab, train=True):
lm_generator = None
lm_prior_model = None

# if the loss function operates on vectors of raw logits instead
# of probabilities, only the first part of the generator needs to
# be passed to the NMTLossCompute. At the moment, the only
# supported loss function of this kind is the sparsemax loss.
use_raw_logits = isinstance(criterion, SparsemaxLoss)
loss_gen = model.generator[0] if use_raw_logits \
else model.generator

compute = cls(criterion, loss_gen,
compute = cls(criterion, model.generator,
normalization=opt.normalization,
copy_attn=opt.copy_attn,
lambda_coverage=opt.lambda_coverage,
@@ -188,9 +178,8 @@ def _compute_lm_loss_ct2(self, output, target):
/fairseq_extension/user/lm_prior/lm_prior.py#L131-L133
"""

# we use the raw logits, rescale with tau (temperature) and
# apply the log_softmax. reminder generator[0] is just the nn.Linear
scores = self.generator[0](self._bottle(output)) / self.lm_prior_tau
# rescale with tau (temperature) and apply the log_softmax.
scores = self.generator(self._bottle(output)) / self.lm_prior_tau
scores = F.log_softmax(scores.to(torch.float32), dim=-1)

src = target.detach().clone()
@@ -223,9 +212,8 @@ def _compute_lm_loss(self, output, target):
/fairseq_extension/user/lm_prior/lm_prior.py#L131-L133
"""

# we use the raw logits, rescale with tau (temperature) and
# apply the log_softmax. reminder generator[0] is just the nn.Linear
scores = self.generator[0](self._bottle(output)) / self.lm_prior_tau
# rescale with tau (temperature) and apply the log_softmax.
scores = self.generator(self._bottle(output)) / self.lm_prior_tau
scores = F.log_softmax(scores.to(torch.float32), dim=-1)

src = target.detach().clone()
@@ -234,7 +222,7 @@ def _compute_lm_loss(self, output, target):
# ct2 expects src with lengths without padding
lm_outs, _ = self.lm_prior_model(src, None, src_len,
with_align=False)
lm_scores = self.lm_prior_model.generator[0](
lm_scores = self.lm_prior_model.generator(
self._bottle(lm_outs)) / self.lm_prior_tau
# again we use raw probs to rescale with tau and apply log_softmax
lm_scores = F.log_softmax(lm_scores.to(torch.float32), dim=-1)
Expand Down Expand Up @@ -308,7 +296,9 @@ def forward(self, batch, output, attns,
else:

scores = self.generator(self._bottle(output))
loss = self.criterion(scores, flat_tgt)
if isinstance(self.criterion, SparsemaxLoss):
scores = LogSparsemax(scores.to(torch.float32), dim=-1)
loss = self.criterion(scores.to(torch.float32), flat_tgt)

if self.lambda_align != 0.0:
align_head = attns['align']
@@ -374,33 +364,3 @@ def _stats(self, bsz, loss, scores, target):
n_sents=bsz,
n_words=num_non_padding,
n_correct=num_correct)


class LabelSmoothingLoss(nn.Module):
"""
With label smoothing,
KL-divergence between q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized.
"""
def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
assert 0.0 < label_smoothing <= 1.0
self.ignore_index = ignore_index
super(LabelSmoothingLoss, self).__init__()

smoothing_value = label_smoothing / (tgt_vocab_size - 2)
one_hot = torch.full((tgt_vocab_size,), smoothing_value)
one_hot[self.ignore_index] = 0
self.register_buffer('one_hot', one_hot.unsqueeze(0))

self.confidence = 1.0 - label_smoothing

def forward(self, output, target):
"""
output (FloatTensor): ``(batch_size, n_classes)``
target (LongTensor): ``(batch_size)``
"""
model_prob = self.one_hot.repeat(target.size(0), 1)
model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)

return F.kl_div(output, model_prob, reduction='sum')
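The custom `LabelSmoothingLoss` above is deleted because `nn.CrossEntropyLoss` handles log-softmax, NLL and label smoothing in one call on the raw logits; smoothing is only applied at training time (`label_smoothing = opt.label_smoothing if train else 0`), and the sparsemax path keeps its own `SparsemaxLoss` criterion. A hedged sketch of the new criterion in isolation (vocab size, padding index and smoothing value are illustrative):

```python
import torch
import torch.nn as nn

vocab_size, padding_idx, label_smoothing = 8000, 1, 0.1
generator = nn.Linear(512, vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=padding_idx,
                                reduction='sum',
                                label_smoothing=label_smoothing)

dec_out = torch.randn(2048, 512)                  # flattened decoder states
flat_tgt = torch.randint(0, vocab_size, (2048,))  # flattened gold tokens
scores = generator(dec_out)                       # raw logits, no LogSoftmax
loss = criterion(scores.to(torch.float32), flat_tgt)
n_tokens = flat_tgt.ne(padding_idx).sum()
(loss / n_tokens).backward()                      # token-level normalization
```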
6 changes: 3 additions & 3 deletions tools/LM_scoring.py
@@ -64,9 +64,9 @@ def main():

vocabs, model, model_opt = load_test_model(opt)
padding_idx = vocabs['tgt'][DefaultTokens.PAD]
criterion = torch.nn.NLLLoss(ignore_index=padding_idx, reduction='none')
loss_gen = model.generator
valid_loss = LossCompute(criterion, loss_gen,
criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx,
reduction='none')
valid_loss = LossCompute(criterion, model.generator,
normalization="tokens",
tgt_shift_index=0,
lambda_coverage=model_opt.lambda_coverage,
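`LM_scoring` switches to the same raw-logit criterion but with `reduction='none'`, because it needs one loss value per token in order to report per-sentence negative log-likelihood and perplexity. A small sketch of that computation under the same illustrative assumptions as above:

```python
import torch
import torch.nn as nn

vocab_size, padding_idx = 8000, 1
criterion = nn.CrossEntropyLoss(ignore_index=padding_idx, reduction='none')

logits = torch.randn(2, 7, vocab_size)        # (batch, seq, vocab) raw scores
target = torch.randint(2, vocab_size, (2, 7))
target[1, 5:] = padding_idx                   # second sentence is shorter

tok_loss = criterion(logits.view(-1, vocab_size), target.view(-1)).view(2, 7)
lengths = target.ne(padding_idx).sum(dim=1)
sent_nll = tok_loss.sum(dim=1) / lengths      # average NLL per sentence
ppl = sent_nll.exp()                          # per-sentence perplexity
print(sent_nll.tolist(), ppl.tolist())
```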
