
Commit 3a79c67

Faster NMT training for machine.py (#227)
* Use auto gradient accumulation steps via accelerate
* Report progress only at the 'percent' step
* Update transformers
* Set output_attentions to False in nmt_model_factory to enable SDPA
1 parent 0e707ce commit 3a79c67


6 files changed, +2327 -2299 lines changed


machine/jobs/build_nmt_engine.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 def run(args: dict) -> None:
     progress: Optional[Callable[[ProgressStatus], None]] = None
     check_canceled: Optional[Callable[[], None]] = None
-    task = None
+    task: Optional[Task] = None
     if args["clearml"]:
         task = Task.init()
3131

machine/jobs/huggingface/hugging_face_nmt_model_factory.py

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ def create_engine(self) -> TranslationEngine:
             batch_size=self._config.huggingface.generate_params.batch_size,
             truncation=TruncationStrategy.LONGEST_FIRST,
             oom_batch_size_backoff_mult=self._config.huggingface.generate_params.oom_batch_size_backoff_mult,
+            output_attentions=False,
         )

     def save_model(self) -> Path:
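Passing output_attentions=False matters because PyTorch's scaled_dot_product_attention (SDPA) kernels cannot return per-head attention weights, so requesting them pushes transformers back onto the slower eager attention path. A minimal sketch of the idea, not taken from this repo; the checkpoint name and the explicit attn_implementation argument are illustrative assumptions (recent transformers versions pick SDPA by default when it is available):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/nllb-200-distilled-600M"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Name SDPA explicitly just to make the dependency visible.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, attn_implementation="sdpa")

batch = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    # output_attentions=False keeps the fused SDPA kernel active; asking for
    # attention weights would fall back to the eager implementation.
    ids = model.generate(**batch, output_attentions=False, max_new_tokens=20)
print(tokenizer.batch_decode(ids, skip_special_tokens=True))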

machine/jobs/settings.yaml

Lines changed: 3 additions & 2 deletions
@@ -10,14 +10,15 @@ default:
       do_train: true
       optim: adamw_torch
       warmup_steps: 1000
-      per_device_train_batch_size: 16
-      gradient_accumulation_steps: 4
+      per_device_train_batch_size: 64
+      gradient_accumulation_steps: 1
       label_smoothing_factor: 0.2
       group_by_length: true
       gradient_checkpointing: true
       lr_scheduler_type: cosine
       learning_rate: 0.0002
       fp16: true
+      tf32: true
       save_strategy: no
       max_steps: 5000
     generate_params:
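The effective batch per optimizer step is unchanged (16 x 4 = 64 before, 64 x 1 = 64 after); the speedup comes from feeding the GPU larger per-device batches and from enabling TensorFloat-32 matmuls alongside fp16. A rough sketch of how these train_params map onto Hugging Face Seq2SeqTrainingArguments, assuming they are passed through directly (output_dir is a placeholder, not part of the config):

from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="out",               # placeholder path
    do_train=True,
    optim="adamw_torch",
    warmup_steps=1000,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # backed off automatically on OOM (see trainer diff below)
    label_smoothing_factor=0.2,
    group_by_length=True,
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    learning_rate=0.0002,
    fp16=True,
    tf32=True,                      # TensorFloat-32 matmuls; needs an Ampere or newer GPU
    save_strategy="no",
    max_steps=5000,
)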

machine/translation/huggingface/hugging_face_nmt_model_trainer.py

Lines changed: 83 additions & 4 deletions
@@ -4,34 +4,41 @@
 import os
 import re
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 import torch  # pyright: ignore[reportMissingImports]
+from accelerate import Accelerator  # pyright: ignore[reportMissingImports]
+from accelerate.utils.memory import should_reduce_batch_size  # pyright: ignore[reportMissingImports]
 from datasets.arrow_dataset import Dataset
 from sacremoses import MosesPunctNormalizer
 from torch import Tensor  # pyright: ignore[reportMissingImports]
+from torch.nn import Module  # pyright: ignore[reportMissingImports]
+from torch.optim.lr_scheduler import LambdaLR  # pyright: ignore[reportMissingImports]
+from torch.optim.optimizer import Optimizer  # pyright: ignore[reportMissingImports]
 from torch.utils.checkpoint import checkpoint  # pyright: ignore[reportMissingImports] # noqa: F401
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     DataCollatorForSeq2Seq,
+    EvalPrediction,
     M2M100ForConditionalGeneration,
     M2M100Tokenizer,
+    MBart50Tokenizer,
     MBart50TokenizerFast,
     MBartTokenizer,
     MBartTokenizerFast,
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
     PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     TrainerCallback,
     set_seed,
 )
-from transformers.models.mbart50 import MBart50Tokenizer
 from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import TrainingArguments

@@ -315,7 +322,7 @@ def preprocess_function(examples):
             pad_to_multiple_of=8 if self._training_args.fp16 else None,
         )

-        self._trainer = Seq2SeqTrainer(
+        self._trainer = AutoGradientAccumulationStepsSeq2SeqTrainer(
             model=model,
             args=self._training_args,
             train_dataset=cast(Any, train_dataset),

@@ -372,10 +379,12 @@ def __init__(
         max_steps: Optional[int],
         progress: Optional[Callable[[ProgressStatus], None]],
         check_canceled: Optional[Callable[[], None]],
+        update_frequency: Optional[int] = None,
     ) -> None:
-        self._max_steps = max_steps
+        self._max_steps = max_steps if max_steps is not None else 0
         self._progress = progress
         self._check_canceled = check_canceled
+        self._update_frequency = update_frequency if update_frequency is not None else max((self._max_steps // 100), 1)

     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
         if self._check_canceled is not None:

@@ -387,6 +396,9 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
             )

     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+        if (state.global_step % self._update_frequency) != 0:
+            return
+
         if self._check_canceled is not None:
             self._check_canceled()
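With update_frequency left unset, the callback now reports progress and checks for cancellation only at roughly one-percent steps instead of on every optimizer step. A quick illustration of the arithmetic (not code from the repo), assuming the max_steps of 5000 from settings.yaml:

max_steps = 5000
update_frequency = max(max_steps // 100, 1)  # -> 50 steps, i.e. ~1% of the run
report_steps = [step for step in range(1, max_steps + 1) if step % update_frequency == 0]
print(update_frequency, len(report_steps))   # prints: 50 100 (a report every 50 steps, 100 reports in total)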

@@ -398,6 +410,73 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
             )


+class AutoGradientAccumulationStepsSeq2SeqTrainer(Seq2SeqTrainer):
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, Module],
+        args: Seq2SeqTrainingArguments,
+        data_collator: Any,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        optimizers: Tuple[Optional[Optimizer], Optional[LambdaLR]] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+    ):
+        super().__init__(
+            model,
+            args,
+            data_collator,
+            train_dataset,  # type: ignore
+            eval_dataset,  # type: ignore
+            tokenizer,
+            model_init,
+            compute_metrics,
+            callbacks,
+            optimizers,  # type: ignore
+            preprocess_logits_for_metrics,
+        )
+
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        inner_training_loop = find_executable_batch_size(super()._inner_training_loop, batch_size, self.accelerator)
+        return inner_training_loop(
+            args=args,
+            resume_from_checkpoint=resume_from_checkpoint,
+            trial=trial,
+            ignore_keys_for_eval=ignore_keys_for_eval,
+        )
+
+
+def find_executable_batch_size(function: Callable, starting_batch_size, accelerator: Accelerator):
+    batch_size = starting_batch_size
+
+    def decorator(*args, **kwargs):
+        nonlocal batch_size
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        while True:
+            if batch_size == 0:
+                raise RuntimeError("No executable batch size found, reached zero.")
+            try:
+                return function(batch_size, *args, **kwargs)
+            except Exception as e:
+                if should_reduce_batch_size(e):
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                    batch_size //= 2
+                    accelerator.gradient_accumulation_steps = accelerator.gradient_accumulation_steps * 2
+                    kwargs["args"].gradient_accumulation_steps = accelerator.gradient_accumulation_steps
+                else:
+                    raise
+
+    return decorator
+
+
 def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
     if isinstance(tokenizer, M2M100Tokenizer):
         lang_token = "__" + lang_code + "__"
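When _inner_training_loop hits a CUDA out-of-memory error, the wrapper halves the per-device batch size and doubles gradient_accumulation_steps on both the Accelerator and the TrainingArguments, so the effective batch per optimizer step is preserved while peak memory drops. A small worked example of that backoff, assuming the new defaults of batch size 64 and accumulation 1; the trace is hypothetical, not trainer output:

batch_size, grad_accum = 64, 1
for _ in range(3):  # pretend the first three attempts each hit an OOM
    batch_size //= 2
    grad_accum *= 2
    print(f"retry: batch_size={batch_size}, grad_accum={grad_accum}, effective={batch_size * grad_accum}")
# retry: batch_size=32, grad_accum=2, effective=64
# retry: batch_size=16, grad_accum=4, effective=64
# retry: batch_size=8, grad_accum=8, effective=64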
