2 changes: 1 addition & 1 deletion machine/jobs/build_nmt_engine.py
@@ -25,7 +25,7 @@
def run(args: dict) -> None:
progress: Optional[Callable[[ProgressStatus], None]] = None
check_canceled: Optional[Callable[[], None]] = None
task = None
task: Optional[Task] = None
if args["clearml"]:
task = Task.init()

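Aside (not part of the diff): the `Optional[Task]` annotation makes the variable's intended type explicit instead of leaving a type checker to infer it from the bare `None` initializer. A minimal sketch of the same pattern; the `use_clearml` flag and the cleanup call are illustrative assumptions, not code from this PR:

from typing import Optional

from clearml import Task  # assumed import; the job uses ClearML's Task


def run_job(use_clearml: bool) -> None:
    task: Optional[Task] = None  # explicit: may hold a Task later, or stay None
    if use_clearml:
        task = Task.init()
    # ... run the job ...
    if task is not None:
        task.close()  # illustrative cleanup, not in the diff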
1 change: 1 addition & 0 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py
@@ -99,6 +99,7 @@ def create_engine(self) -> TranslationEngine:
batch_size=self._config.huggingface.generate_params.batch_size,
truncation=TruncationStrategy.LONGEST_FIRST,
oom_batch_size_backoff_mult=self._config.huggingface.generate_params.oom_batch_size_backoff_mult,
output_attentions=False,
)

def save_model(self) -> Path:
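Aside (not part of the diff): `output_attentions=False` is passed through to Hugging Face generation, presumably so per-layer attention weights are never materialized during inference, which trims GPU memory. A minimal sketch of the flag on a plain transformers call; the model name and input text are illustrative, and the plumbing inside the engine class may differ:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/nllb-200-distilled-600M"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("Hello, world.", return_tensors="pt")
out = model.generate(
    **inputs,
    return_dict_in_generate=True,
    output_attentions=False,  # do not return encoder/decoder attention tensors
    max_new_tokens=32,
)
print(out.sequences.shape)  # only token ids come back; no attention tensors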
5 changes: 3 additions & 2 deletions machine/jobs/settings.yaml
@@ -10,14 +10,15 @@ default:
do_train: true
optim: adamw_torch
warmup_steps: 1000
per_device_train_batch_size: 16
gradient_accumulation_steps: 4
per_device_train_batch_size: 64
gradient_accumulation_steps: 1
label_smoothing_factor: 0.2
group_by_length: true
gradient_checkpointing: true
lr_scheduler_type: cosine
learning_rate: 0.0002
fp16: true
tf32: true
save_strategy: no
max_steps: 5000
generate_params:
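Aside (not part of the diff): the effective batch per optimizer step is unchanged, since 16 × 4 and 64 × 1 both give 64 samples; doing it in one larger forward/backward pass simply uses the GPU more efficiently, and `tf32: true` additionally enables TensorFloat-32 matmuls on Ampere-class GPUs. A minimal sketch of what these keys map to in `Seq2SeqTrainingArguments`; the `output_dir` is an illustrative placeholder:

from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="out",               # illustrative placeholder
    optim="adamw_torch",
    warmup_steps=1000,
    per_device_train_batch_size=64,  # was 16
    gradient_accumulation_steps=1,   # was 4; 64 * 1 == 16 * 4, so the effective batch is unchanged
    label_smoothing_factor=0.2,
    group_by_length=True,
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    learning_rate=0.0002,
    fp16=True,
    tf32=True,                       # TensorFloat-32 matmuls on Ampere+ GPUs
    save_strategy="no",
    max_steps=5000,
)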
87 changes: 83 additions & 4 deletions machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -4,34 +4,41 @@
import os
import re
from pathlib import Path
from typing import Any, Callable, List, Optional, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch # pyright: ignore[reportMissingImports]
from accelerate import Accelerator # pyright: ignore[reportMissingImports]
from accelerate.utils.memory import should_reduce_batch_size # pyright: ignore[reportMissingImports]
from datasets.arrow_dataset import Dataset
from sacremoses import MosesPunctNormalizer
from torch import Tensor # pyright: ignore[reportMissingImports]
from torch.nn import Module # pyright: ignore[reportMissingImports]
from torch.optim.lr_scheduler import LambdaLR # pyright: ignore[reportMissingImports]
from torch.optim.optimizer import Optimizer # pyright: ignore[reportMissingImports]
from torch.utils.checkpoint import checkpoint # pyright: ignore[reportMissingImports] # noqa: F401
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
EvalPrediction,
M2M100ForConditionalGeneration,
M2M100Tokenizer,
MBart50Tokenizer,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
NllbTokenizer,
NllbTokenizerFast,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
TrainerCallback,
set_seed,
)
from transformers.models.mbart50 import MBart50Tokenizer
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import TrainingArguments
@@ -315,7 +322,7 @@ def preprocess_function(examples):
pad_to_multiple_of=8 if self._training_args.fp16 else None,
)

self._trainer = Seq2SeqTrainer(
self._trainer = AutoGradientAccumulationStepsSeq2SeqTrainer(
model=model,
args=self._training_args,
train_dataset=cast(Any, train_dataset),
@@ -372,10 +379,12 @@ def __init__(
max_steps: Optional[int],
progress: Optional[Callable[[ProgressStatus], None]],
check_canceled: Optional[Callable[[], None]],
update_frequency: Optional[int] = None,
) -> None:
self._max_steps = max_steps
self._max_steps = max_steps if max_steps is not None else 0
self._progress = progress
self._check_canceled = check_canceled
self._update_frequency = update_frequency if update_frequency is not None else max((self._max_steps // 100), 1)

def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
if self._check_canceled is not None:
@@ -387,6 +396,9 @@ def on_train_begin(args: TrainingArguments, state: TrainerState, control:
)

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
if (state.global_step % self._update_frequency) != 0:
return

if self._check_canceled is not None:
self._check_canceled()

@@ -398,6 +410,73 @@ def on_step_end(args: TrainingArguments, state: TrainerState, control: Tra
)


class AutoGradientAccumulationStepsSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(
self,
model: Union[PreTrainedModel, Module],
args: Seq2SeqTrainingArguments,
data_collator: Any,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[Optional[Optimizer], Optional[LambdaLR]] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
):
super().__init__(
model,
args,
data_collator,
train_dataset, # type: ignore
eval_dataset, # type: ignore
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers, # type: ignore
preprocess_logits_for_metrics,
)

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
inner_training_loop = find_executable_batch_size(super()._inner_training_loop, batch_size, self.accelerator)
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)


def find_executable_batch_size(function: Callable, starting_batch_size, accelerator: Accelerator):
batch_size = starting_batch_size

def decorator(*args, **kwargs):
nonlocal batch_size
gc.collect()
torch.cuda.empty_cache()

while True:
if batch_size == 0:
raise RuntimeError("No executable batch size found, reached zero.")
try:
return function(batch_size, *args, **kwargs)
except Exception as e:
if should_reduce_batch_size(e):
gc.collect()
torch.cuda.empty_cache()
batch_size //= 2
accelerator.gradient_accumulation_steps = accelerator.gradient_accumulation_steps * 2
kwargs["args"].gradient_accumulation_steps = accelerator.gradient_accumulation_steps
else:
raise

return decorator


def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
if isinstance(tokenizer, M2M100Tokenizer):
lang_token = "__" + lang_code + "__"
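Aside (not part of the diff): the `find_executable_batch_size` above differs from accelerate's stock helper of the same name in that, besides halving the batch size after a CUDA out-of-memory error (detected with `should_reduce_batch_size`), it doubles `gradient_accumulation_steps` on both the `Accelerator` and the training arguments, so the effective batch per optimizer step stays constant. A minimal standalone sketch of that backoff pattern; `train_step` and its simulated OOM threshold are made up for illustration:

import gc


def backoff(function, starting_batch_size, grad_accum_steps):
    batch_size = starting_batch_size

    def wrapper(*args, **kwargs):
        nonlocal batch_size, grad_accum_steps
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, grad_accum_steps, *args, **kwargs)
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    gc.collect()
                    batch_size //= 2       # halve the per-device batch size...
                    grad_accum_steps *= 2  # ...and double accumulation: their product is constant
                else:
                    raise

    return wrapper


def train_step(batch_size, grad_accum_steps):
    # Stand-in for _inner_training_loop: pretend anything above 16 OOMs.
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory (simulated)")
    return batch_size, grad_accum_steps


print(backoff(train_step, 64, 1)())  # -> (16, 4)

With the settings above (per_device_train_batch_size: 64, gradient_accumulation_steps: 1), a single OOM would drop the trainer to 32 × 2, then 16 × 4, and so on, until the loop either succeeds or reaches zero.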