diff --git a/mart/callbacks/eval_mode.py b/mart/callbacks/eval_mode.py
index be3b6397..639444c9 100644
--- a/mart/callbacks/eval_mode.py
+++ b/mart/callbacks/eval_mode.py
@@ -4,23 +4,47 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 from pytorch_lightning.callbacks import Callback
 
+from mart import utils
+
+logger = utils.get_pylogger(__name__)
+
 __all__ = ["AttackInEvalMode"]
 
 
 class AttackInEvalMode(Callback):
     """Switch the model into eval mode during attack."""
 
-    def __init__(self):
-        self.training_mode_status = None
-
-    def on_train_start(self, trainer, model):
-        self.training_mode_status = model.training
-        model.train(False)
-
-    def on_train_end(self, trainer, model):
-        assert self.training_mode_status is not None
-
-        # Resume the previous training status of the model.
-        model.train(self.training_mode_status)
+    def __init__(self, module_classes: type | list[type]):
+        # FIXME: Convert strings to classes using hydra.utils.get_class? That would reduce configuration verbosity but requires importing hydra in this callback.
+        if isinstance(module_classes, type):
+            module_classes = [module_classes]
+
+        self.module_classes = tuple(module_classes)
+
+    def setup(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        # Log to the console so the user can see which modules will be put in eval mode during training.
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                logger.info(
+                    f"Will set eval mode for {name} ({module.__class__.__module__}.{module.__class__.__name__})"
+                )
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        # We must use on_train_epoch_start because PL sets pl_module back to train mode right before this hook runs.
+        # See: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
+        for module in pl_module.modules():
+            if isinstance(module, self.module_classes):
+                module.eval()
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        # FIXME: Why is this necessary?
+        for module in pl_module.modules():
+            if isinstance(module, self.module_classes):
+                module.train()
diff --git a/mart/callbacks/no_grad_mode.py b/mart/callbacks/no_grad_mode.py
index cfb90ead..4a86d985 100644
--- a/mart/callbacks/no_grad_mode.py
+++ b/mart/callbacks/no_grad_mode.py
@@ -4,8 +4,14 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 from pytorch_lightning.callbacks import Callback
 
+from mart import utils
+
+logger = utils.get_pylogger(__name__)
+
 __all__ = ["ModelParamsNoGrad"]
 
 
@@ -15,10 +21,28 @@ class ModelParamsNoGrad(Callback):
     This callback should not change the result. Don't use unless an attack runs faster.
     """
 
-    def on_train_start(self, trainer, model):
-        for param in model.parameters():
-            param.requires_grad_(False)
+    def __init__(self, module_names: str | list[str] | None = None):
+        if isinstance(module_names, str):
+            module_names = [module_names]
+
+        self.module_names = module_names or []
+
+    def setup(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        # We use setup, rather than on_train_start, so that mart.optim.OptimizerFactory can ignore parameters with no gradients.
+        # See: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
+        for name, param in pl_module.named_parameters():
+            if any(name.startswith(module_name) for module_name in self.module_names):
+                logger.info(f"Disabling gradient for {name}")
+                param.requires_grad_(False)
 
-    def on_train_end(self, trainer, model):
-        for param in model.parameters():
-            param.requires_grad_(True)
+    def teardown(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        for name, param in pl_module.named_parameters():
+            if any(name.startswith(module_name) for module_name in self.module_names):
+                # FIXME: Why is this necessary?
+                param.requires_grad_(True)
diff --git a/mart/configs/callbacks/attack_in_eval_mode.yaml b/mart/configs/callbacks/attack_in_eval_mode.yaml
index 2acdc953..4ca096b0 100644
--- a/mart/configs/callbacks/attack_in_eval_mode.yaml
+++ b/mart/configs/callbacks/attack_in_eval_mode.yaml
@@ -1,2 +1,11 @@
 attack_in_eval_mode:
   _target_: mart.callbacks.AttackInEvalMode
+  module_classes: ???
+  # - _target_: hydra.utils.get_class
+  #   path: mart.models.LitModular
+  # - _target_: hydra.utils.get_class
+  #   path: torch.nn.BatchNorm2d
+  # - _target_: hydra.utils.get_class
+  #   path: torch.nn.Dropout
+  # - _target_: hydra.utils.get_class
+  #   path: torch.nn.SyncBatchNorm
diff --git a/mart/configs/callbacks/no_grad_mode.yaml b/mart/configs/callbacks/no_grad_mode.yaml
index 6b4312fd..d12d18e9 100644
--- a/mart/configs/callbacks/no_grad_mode.yaml
+++ b/mart/configs/callbacks/no_grad_mode.yaml
@@ -1,2 +1,3 @@
-attack_in_eval_mode:
+no_grad_mode:
   _target_: mart.callbacks.ModelParamsNoGrad
+  module_names: ???
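
For reviewers, a minimal usage sketch (not part of the patch) showing how the two callbacks compose directly in Python. The `torch.nn` classes mirror the commented-out Hydra example above; the "backbone" submodule name and the `model`/`datamodule` objects are hypothetical placeholders:

```python
import pytorch_lightning as pl
import torch

from mart.callbacks import AttackInEvalMode, ModelParamsNoGrad

callbacks = [
    # Keep normalization and dropout layers in eval mode while the attack trains.
    AttackInEvalMode(module_classes=[torch.nn.BatchNorm2d, torch.nn.Dropout]),
    # Disable gradients for every parameter whose name starts with "backbone" (hypothetical submodule).
    ModelParamsNoGrad(module_names="backbone"),
]

trainer = pl.Trainer(max_epochs=1, callbacks=callbacks)
# trainer.fit(model, datamodule=datamodule)  # model and datamodule defined elsewhere
```

With Hydra, the equivalent wiring is the `module_classes` and `module_names` fields in the two YAML files above, where classes are resolved via `hydra.utils.get_class` as in the commented-out example.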