From f8757027193aea5f7165d8f21c81baee9e5643fc Mon Sep 17 00:00:00 2001
From: Sangkug Lym
Date: Wed, 14 Jun 2023 15:50:36 -0700
Subject: [PATCH] Apply garbage collection interval to validation steps (#6870)

* Apply garbage collection interval to validation steps

Signed-off-by: Sangkug Lym

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sangkug Lym
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../language_modeling/megatron_base_model.py  | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index b96a5adb0c62..f3ae5a938a61 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -157,6 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
         # The automatic garbage collector should be disabled before training starts.
         if self.gc_interval > 0:
             gc.disable()
+        self.validation_global_step = 1
 
     def _enable_nvidia_optimizations(self):
         "These optimizations are present in NVIDIA NGC PyTorch Containers"
@@ -218,6 +219,16 @@ def on_train_start(self) -> None:
         super().on_train_start()
         self.init_global_step = self.trainer.global_step
 
+    def on_validation_start(self) -> None:
+        super().on_validation_start()
+        if self.gc_interval > 0:
+            gc.collect()
+
+    def on_validation_end(self) -> None:
+        super().on_validation_end()
+        if self.gc_interval > 0:
+            gc.collect()
+
     def _build_vocab(self):
         """
         Manipulate vocabulary (e.g., pad vocabulary for increased performance)/
@@ -366,6 +377,14 @@ def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unus
         if self.gc_interval > 0 and (self.trainer.global_step % self.gc_interval == 0):
             gc.collect()
 
+    def on_validation_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
+
+        if self.gc_interval > 0:
+            if self.validation_global_step % self.gc_interval == 0:
+                gc.collect()
+            self.validation_global_step += 1
+
     def setup_optimization(
         self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
     ):
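
For context, below is a minimal standalone sketch of the pattern this patch extends. It assumes only the Python standard library; the GCController class, its driver loop, and its method names (chosen to mirror the Lightning hooks) are hypothetical illustrations, not NeMo's API. The idea: when gc_interval > 0, automatic garbage collection is disabled up front and gc.collect() is invoked manually every gc_interval steps. Validation needs its own counter (validation_global_step) because trainer.global_step only advances on training steps, so the training-side modulo check would never fire inside a validation loop.

    import gc

    class GCController:
        """Hypothetical reduction of the patch's manual-GC pattern (not NeMo's API)."""

        def __init__(self, gc_interval: int):
            self.gc_interval = gc_interval
            if self.gc_interval > 0:
                # Disable automatic GC; collection now happens only at step
                # boundaries, avoiding unpredictable pauses mid-step.
                gc.disable()
            # Separate counter for validation steps, mirroring the patch's
            # self.validation_global_step = 1.
            self.validation_global_step = 1

        def on_train_batch_end(self, global_step: int) -> None:
            # Mirrors the pre-existing training-side check in on_train_batch_end.
            if self.gc_interval > 0 and global_step % self.gc_interval == 0:
                gc.collect()

        def on_validation_batch_end(self) -> None:
            # Mirrors the hook added by this patch: collect every gc_interval
            # validation batches, tracked independently of global_step.
            if self.gc_interval > 0:
                if self.validation_global_step % self.gc_interval == 0:
                    gc.collect()
                self.validation_global_step += 1

    # Illustrative usage: a training loop with periodic validation.
    ctl = GCController(gc_interval=100)
    for step in range(1, 1001):
        ctl.on_train_batch_end(global_step=step)
        if step % 500 == 0:
            for _ in range(50):
                ctl.on_validation_batch_end()

The patch additionally collects once in on_validation_start and on_validation_end, which brackets each validation run so cyclic garbage left over from training is reclaimed before validation begins and validation garbage is reclaimed before training resumes.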