 import os
 import re
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 import torch  # pyright: ignore[reportMissingImports]
+from accelerate import Accelerator  # pyright: ignore[reportMissingImports]
+from accelerate.utils.memory import should_reduce_batch_size  # pyright: ignore[reportMissingImports]
 from datasets.arrow_dataset import Dataset
 from sacremoses import MosesPunctNormalizer
 from torch import Tensor  # pyright: ignore[reportMissingImports]
+from torch.nn import Module  # pyright: ignore[reportMissingImports]
+from torch.optim.lr_scheduler import LambdaLR  # pyright: ignore[reportMissingImports]
+from torch.optim.optimizer import Optimizer  # pyright: ignore[reportMissingImports]
 from torch.utils.checkpoint import checkpoint  # pyright: ignore[reportMissingImports] # noqa: F401
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     DataCollatorForSeq2Seq,
+    EvalPrediction,
     M2M100ForConditionalGeneration,
     M2M100Tokenizer,
+    MBart50Tokenizer,
     MBart50TokenizerFast,
     MBartTokenizer,
     MBartTokenizerFast,
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
     PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     TrainerCallback,
     set_seed,
 )
-from transformers.models.mbart50 import MBart50Tokenizer
 from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import TrainingArguments
@@ -315,7 +322,7 @@ def preprocess_function(examples):
             pad_to_multiple_of=8 if self._training_args.fp16 else None,
         )

-        self._trainer = Seq2SeqTrainer(
+        self._trainer = AutoGradientAccumulationStepsSeq2SeqTrainer(
             model=model,
             args=self._training_args,
             train_dataset=cast(Any, train_dataset),
@@ -372,10 +379,12 @@ def __init__(
         max_steps: Optional[int],
         progress: Optional[Callable[[ProgressStatus], None]],
         check_canceled: Optional[Callable[[], None]],
+        update_frequency: Optional[int] = None,
     ) -> None:
-        self._max_steps = max_steps
+        self._max_steps = max_steps if max_steps is not None else 0
         self._progress = progress
         self._check_canceled = check_canceled
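+        # Default to reporting progress roughly every 1% of max_steps (at least every step) when no
+        # update_frequency is given.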
+        self._update_frequency = update_frequency if update_frequency is not None else max((self._max_steps // 100), 1)

     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
         if self._check_canceled is not None:
@@ -387,6 +396,9 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
             )

     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
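+        # Only check for cancellation and report progress every _update_frequency steps to reduce overhead.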
+        if (state.global_step % self._update_frequency) != 0:
+            return
+
         if self._check_canceled is not None:
             self._check_canceled()

@@ -398,6 +410,73 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
             )


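+# Seq2SeqTrainer that recovers from out-of-memory errors during training by halving the batch size and
+# doubling gradient accumulation, keeping the effective batch size constant.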
+class AutoGradientAccumulationStepsSeq2SeqTrainer(Seq2SeqTrainer):
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, Module],
+        args: Seq2SeqTrainingArguments,
+        data_collator: Any,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        optimizers: Tuple[Optional[Optimizer], Optional[LambdaLR]] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+    ):
+        super().__init__(
+            model,
+            args,
+            data_collator,
+            train_dataset,  # type: ignore
+            eval_dataset,  # type: ignore
+            tokenizer,
+            model_init,
+            compute_metrics,
+            callbacks,
+            optimizers,  # type: ignore
+            preprocess_logits_for_metrics,
+        )
+
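+    # Wrap the base training loop so it can be retried with progressively smaller batch sizes
+    # (see find_executable_batch_size below).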
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        inner_training_loop = find_executable_batch_size(super()._inner_training_loop, batch_size, self.accelerator)
+        return inner_training_loop(
+            args=args,
+            resume_from_checkpoint=resume_from_checkpoint,
+            trial=trial,
+            ignore_keys_for_eval=ignore_keys_for_eval,
+        )
+
+
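+# Modeled on accelerate's find_executable_batch_size utility: retry `function` with a smaller batch size
+# whenever it fails with an out-of-memory error, doubling gradient accumulation to compensate.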
+def find_executable_batch_size(function: Callable, starting_batch_size, accelerator: Accelerator):
+    batch_size = starting_batch_size
+
+    def decorator(*args, **kwargs):
+        nonlocal batch_size
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        while True:
+            if batch_size == 0:
+                raise RuntimeError("No executable batch size found, reached zero.")
+            try:
+                return function(batch_size, *args, **kwargs)
+            except Exception as e:
+                if should_reduce_batch_size(e):
+                    gc.collect()
+                    torch.cuda.empty_cache()
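+                    # Halve the per-device batch size and double gradient accumulation so the effective
+                    # batch size stays the same; keep the TrainingArguments in sync with the accelerator.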
+                    batch_size //= 2
+                    accelerator.gradient_accumulation_steps = accelerator.gradient_accumulation_steps * 2
+                    kwargs["args"].gradient_accumulation_steps = accelerator.gradient_accumulation_steps
+                else:
+                    raise
+
+    return decorator
+
+
 def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
     if isinstance(tokenizer, M2M100Tokenizer):
         lang_token = "__" + lang_code + "__"