diff --git a/mart/attack/__init__.py b/mart/attack/__init__.py index f86d5ad8..19870873 100644 --- a/mart/attack/__init__.py +++ b/mart/attack/__init__.py @@ -1,12 +1,11 @@ from .adversary import * from .adversary_in_art import * from .adversary_wrapper import * -from .callbacks import Callback from .composer import * from .enforcer import * from .gain import * from .gradient_modifier import * from .initializer import * -from .objective import Objective +from .objective import * from .perturber import * from .projector import * diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 7f6d9b29..065a1a65 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -6,13 +6,16 @@ from __future__ import annotations -from collections import OrderedDict +from functools import partial +from itertools import cycle from typing import TYPE_CHECKING, Any, Callable +import pytorch_lightning as pl import torch +from mart.utils import silent + from ..optim import OptimizerFactory -from .callbacks import Callback if TYPE_CHECKING: from .composer import Composer @@ -22,65 +25,11 @@ from .objective import Objective from .perturber import Perturber -__all__ = ["Adversary", "Attacker"] - - -class AttackerCallbackHookMixin(Callback): - """Define event hooks in the Adversary Loop for callbacks.""" - - callbacks = {} - - def on_run_start(self, **kwargs) -> None: - """Prepare the attack loop state.""" - for _name, callback in self.callbacks.items(): - # FIXME: Skip incomplete callback instance. - # Give access of self to callbacks by `adversary=self`. - callback.on_run_start(**kwargs) - - def on_examine_start(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_examine_start(**kwargs) - - def on_examine_end(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_examine_end(**kwargs) - - def on_advance_start(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_advance_start(**kwargs) - - def on_advance_end(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_advance_end(**kwargs) - - def on_run_end(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_run_end(**kwargs) - +__all__ = ["Adversary"] -class Attacker(AttackerCallbackHookMixin, torch.nn.Module): - """The attack optimization loop. - This class implements the following loop structure: - - .. code-block:: python - - on_run_start() - - while true: - on_examine_start() - examine() - on_examine_end() - - if not done: - on_advance_start() - advance() - on_advance_end() - else: - break - - on_run_end() - """ +class Adversary(pl.LightningModule): + """An adversary module which generates and applies perturbation to input.""" def __init__( self, @@ -88,269 +37,166 @@ def __init__( perturber: Perturber, composer: Composer, optimizer: OptimizerFactory | Callable[[Any], torch.optim.Optimizer], - max_iters: int, gain: Gain, - objective: Objective | None = None, - callbacks: dict[str, Callback] | None = None, gradient_modifier: GradientModifier | None = None, + objective: Objective | None = None, + enforcer: Enforcer | None = None, + attacker: pl.Trainer | None = None, + **kwargs, ): """_summary_ Args: - perturber (Perturber): A module that stores perturbations. - composer (Composer): A module which composes adversarial examples from input and perturbation. 
- optimizer (OptimizerFactory | Callable[[Any], torch.optim.Optimizer]): A partial that returns an Optimizer when given params. - max_iters (int): The max number of attack iterations. + perturber (Perturber): A MART Perturber. + composer (Composer): A MART Composer. + optimizer (OptimizerFactory | Callable[[Any], torch.optim.Optimizer]): A MART OptimizerFactory or partial that returns an Optimizer when given params. gain (Gain): An adversarial gain function, which is a differentiable estimate of adversarial objective. - objective (Objective | None): A function for computing adversarial objective, which returns True or False. Optional. - callbacks (dict[str, Callback] | None): A dictionary of callback objects. Optional. + gradient_modifier (GradientModifier): To modify the gradient of perturbation. + objective (Objective): A function for computing adversarial objective, which returns True or False. Optional. + enforcer (Enforcer): A Callable that enforce constraints on the adversarial input. + attacker (Trainer): A PyTorch-Lightning Trainer object used to fit the perturbation. """ super().__init__() # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint. self._perturber = [perturber] self.composer = composer - self.optimizer_fn = optimizer - if not isinstance(self.optimizer_fn, OptimizerFactory): - self.optimizer_fn = OptimizerFactory(self.optimizer_fn) - - self.max_iters = max_iters - self.callbacks = OrderedDict() - - if callbacks is not None: - self.callbacks.update(callbacks) - - self.objective_fn = objective - # self.gain is a tensor. + self.optimizer = optimizer + if not isinstance(self.optimizer, OptimizerFactory): + self.optimizer = OptimizerFactory(self.optimizer) self.gain_fn = gain self.gradient_modifier = gradient_modifier + self.objective_fn = objective + self.enforcer = enforcer + + self._attacker = attacker + + if self._attacker is None: + # Enable attack to be late bound in forward + self._attacker = partial( + pl.Trainer, + num_sanity_val_steps=0, + logger=False, + max_epochs=0, + limit_train_batches=kwargs.pop("max_iters", 10), + callbacks=list(kwargs.pop("callbacks", {}).values()), # dict to list of values + enable_model_summary=False, + enable_checkpointing=False, + # We should disable progress bar in the progress_bar callback config if needed. + enable_progress_bar=True, + # detect_anomaly=True, + ) + + else: + # We feed the same batch to the attack every time so we treat each step as an + # attack iteration. As such, attackers must only run for 1 epoch and must limit + # the number of attack steps via limit_train_batches. + assert self._attacker.max_epochs == 0 + assert self._attacker.limit_train_batches > 0 @property def perturber(self) -> Perturber: # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint. return self._perturber[0] - @property - def done(self) -> bool: - # Reach the max iteration; - if self.cur_iter >= self.max_iters: - return True + def configure_optimizers(self): + return self.optimizer(self.perturber) - # All adv. examples are found; - if hasattr(self, "found") and bool(self.found.all()) is True: - return True + def training_step(self, batch, batch_idx): + # copy batch since we modify it and it is used internally + batch = batch.copy() - # Compatible with models which return None gain when objective is reached. 
- # TODO: Remove gain==None stopping criteria in all models, - # because the BestPerturbation callback relies on gain to determine which pert is the best. - if self.gain is None: - return True + # We need to evaluate the perturbation against the whole model, so call it normally to get a gain. + model = batch.pop("model") + outputs = model(**batch) - return False + # FIXME: This should really be just `return outputs`. But this might require a new sequence? + # FIXME: Everything below here should live in the model as modules. + # Use CallWith to dispatch **outputs. + gain = self.gain_fn(**outputs) - def on_run_start( - self, - *, - adversary: torch.nn.Module, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, - ): - super().on_run_start( - adversary=adversary, input=input, target=target, model=model, **kwargs - ) + # objective_fn is optional, because adversaries may never reach their objective. + if self.objective_fn is not None: + found = self.objective_fn(**outputs) - # FIXME: We should probably just register IterativeAdversary as a callback. - # Set up the optimizer. - self.cur_iter = 0 + # No need to calculate new gradients if adversarial examples are already found. + if len(gain.shape) > 0: + gain = gain[~found] - # param_groups with learning rate and other optim params. - self.perturber.configure_perturbation(input) - self.opt = self.optimizer_fn(self.perturber) + if len(gain.shape) > 0: + gain = gain.sum() - def on_run_end( - self, - *, - adversary: torch.nn.Module, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, - ): - super().on_run_end(adversary=adversary, input=input, target=target, model=model, **kwargs) - - # Release optimization resources - del self.opt - - # Disable mixed-precision optimization for attacks, - # since we haven't implemented it yet. - @torch.autocast("cuda", enabled=False) - @torch.autocast("cpu", enabled=False) - def fit( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, - ): + return gain - self.on_run_start(adversary=self, input=input, target=target, model=model, **kwargs) - - while True: - try: - self.on_examine_start( - adversary=self, input=input, target=target, model=model, **kwargs - ) - self.examine(input=input, target=target, model=model, **kwargs) - self.on_examine_end( - adversary=self, input=input, target=target, model=model, **kwargs - ) - - # Check the done condition here, so that every update of perturbation is examined. - if not self.done: - self.on_advance_start( - adversary=self, - input=input, - target=target, - model=model, - **kwargs, - ) - self.advance( - input=input, - target=target, - model=model, - **kwargs, - ) - self.on_advance_end( - adversary=self, - input=input, - target=target, - model=model, - **kwargs, - ) - # Update cur_iter at the end so that all hooks get the correct cur_iter. - self.cur_iter += 1 - else: - break - except StopIteration: - break - - self.on_run_end(adversary=self, input=input, target=target, model=model, **kwargs) - - # Make sure we can do autograd. 
- # Earlier Pytorch Lightning uses no_grad(), but later PL uses inference_mode(): - # https://github.com/Lightning-AI/lightning/pull/12715 - @torch.enable_grad() - @torch.inference_mode(False) - def examine( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, + def configure_gradient_clipping( + self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None ): - """Examine current perturbation, update self.gain and self.found.""" - - # Clone tensors for autograd, in case it was created in the inference mode. - # FIXME: object detection uses non-pure-tensor data, but it may have cloned somewhere else implicitly? - if isinstance(input, torch.Tensor): - input = input.clone() - if isinstance(target, torch.Tensor): - target = target.clone() - - # Set model as None, because no need to update perturbation. - # Save everything to self.outputs so that callbacks have access to them. - self.outputs = model(input=input, target=target, model=None, **kwargs) - - # Use CallWith to dispatch **outputs. - self.gain = self.gain_fn(**self.outputs) - - # objective_fn is optional, because adversaries may never reach their objective. - if self.objective_fn is not None: - self.found = self.objective_fn(**self.outputs) - if self.gain.shape == torch.Size([]): - # A reduced gain value, not an input-wise gain vector. - self.total_gain = self.gain - else: - # No need to calculate new gradients if adversarial examples are already found. - self.total_gain = self.gain[~self.found].sum() - else: - self.total_gain = self.gain.sum() + # Configuring gradient clipping in pl.Trainer is still useful, so use it. + super().configure_gradient_clipping( + optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm + ) - # Make sure we can do autograd. - @torch.enable_grad() - @torch.inference_mode(False) - def advance( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, - ): - """Run one attack iteration.""" + if self.gradient_modifier: + for group in optimizer.param_groups: + self.gradient_modifier(group["params"]) - self.opt.zero_grad() + @silent() + def forward(self, *, model=None, sequence=None, **batch): + batch["model"] = model + batch["sequence"] = sequence - # Do not flip the gain value, because we set maximize=True in optimizer. - self.total_gain.backward() + # Adversary lives within a sequence of model. To signal the adversary should attack, one + # must pass a model to attack when calling the adversary. Since we do not know where the + # Adversary lives inside the model, we also need the remaining sequence to be able to + # get a loss. + if model and sequence: + self._attack(**batch) - if self.gradient_modifier is not None: - for param_group in self.opt.param_groups: - for param in param_group["params"]: - self.gradient_modifier(param) + perturbation = self.perturber(**batch) + input_adv = self.composer(perturbation, **batch) - self.opt.step() + # Enforce constraints after the attack optimization ends. 
+ if model and sequence: + self.enforcer(input_adv, **batch) - def forward( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - **kwargs, - ): - perturbation = self.perturber(input=input, target=target) - output = self.composer(perturbation, input=input, target=target) + return input_adv - return output + def _attack(self, *, input, **batch): + batch["input"] = input + # Configure and reset perturbation for current inputs + self.perturber.configure_perturbation(input) -class Adversary(torch.nn.Module): - """An adversary module which generates and applies perturbation to input.""" + # Attack, aka fit a perturbation, for one epoch by cycling over the same input batch. + # We use Trainer.limit_train_batches to control the number of attack iterations. + self.attacker.fit_loop.max_epochs += 1 + self.attacker.fit(self, train_dataloaders=cycle([batch])) - def __init__(self, *, enforcer: Enforcer, attacker: Attacker | None = None, **kwargs): - """_summary_ + @property + def attacker(self): + if not isinstance(self._attacker, partial): + return self._attacker - Args: - enforcer (Enforcer): A module which checks if adversarial examples satisfy constraints. - attacker (Attacker): A trainer-like object that computes attacks. - """ - super().__init__() + # Convert torch.device to PL accelerator + if self.device.type == "cuda": + accelerator = "gpu" + devices = [self.device.index] - self.enforcer = enforcer - self.attacker = attacker or Attacker(**kwargs) + elif self.device.type == "cpu": + accelerator = "cpu" + devices = None - def forward( - self, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module | None = None, - **kwargs, - ): - # Generate a perturbation only if we have a model. This will update - # the parameters of self.perturber. - if model is not None: - self.attacker.fit(input=input, target=target, model=model, **kwargs) + else: + raise NotImplementedError - # Get perturbation and apply threat model - # The mask projector in perturber may require information from target. - output = self.attacker(input=input, target=target) + self._attacker = self._attacker(accelerator=accelerator, devices=devices) - if model is not None: - # We only enforce constraints after the attack optimization ends. - self.enforcer(output, input=input, target=target) + return self._attacker - return output + def cpu(self): + # PL places the LightningModule back on the CPU after fitting: + # https://github.com/Lightning-AI/lightning/blob/ff5361604b2fd508aa2432babed6844fbe268849/pytorch_lightning/strategies/single_device.py#L96 + # https://github.com/Lightning-AI/lightning/blob/ff5361604b2fd508aa2432babed6844fbe268849/pytorch_lightning/strategies/ddp.py#L482 + # This is a problem when this LightningModule has parameters, so we stop this from + # happening by ignoring the call to cpu(). 
+ pass diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py deleted file mode 100644 index d820f69b..00000000 --- a/mart/attack/callbacks/base.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Copyright (C) 2022 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -from __future__ import annotations - -import abc -from typing import TYPE_CHECKING, Any, Iterable - -import torch - -if TYPE_CHECKING: - from ..adversary import Adversary - -__all__ = ["Callback"] - - -class Callback(abc.ABC): - """Abstract base class of callbacks.""" - - def on_run_start( - self, - *, - adversary: Adversary, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - model: torch.nn.Module, - **kwargs, - ): - pass - - def on_examine_start( - self, - *, - adversary: Adversary, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - model: torch.nn.Module, - **kwargs, - ): - pass - - def on_examine_end( - self, - *, - adversary: Adversary, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - model: torch.nn.Module, - **kwargs, - ): - pass - - def on_advance_start( - self, - *, - adversary: Adversary, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - model: torch.nn.Module, - **kwargs, - ): - pass - - def on_advance_end( - self, - *, - adversary: Adversary, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - model: torch.nn.Module, - **kwargs, - ): - pass - - def on_run_end( - self, - *, - adversary: Adversary, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - model: torch.nn.Module, - **kwargs, - ): - pass diff --git a/mart/attack/callbacks/progress_bar.py b/mart/attack/callbacks/progress_bar.py deleted file mode 100644 index d175aa5d..00000000 --- a/mart/attack/callbacks/progress_bar.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Copyright (C) 2022 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -import tqdm - -from .base import Callback - -__all__ = ["ProgressBar"] - - -class ProgressBar(Callback): - """Display progress bar of attack iterations with the gain value.""" - - def on_run_start(self, *, adversary, **kwargs): - self.pbar = tqdm.tqdm(total=adversary.max_iters, leave=False, desc="Attack", unit="iter") - - def on_examine_end(self, *, input, adversary, **kwargs): - msg = "" - if hasattr(adversary, "found"): - # there is no adversary.found if adversary.objective_fn() is not defined. 
- msg += f"found={int(sum(adversary.found))}/{len(input)}, " - - msg += f"avg_gain={float(adversary.gain.mean()):.2f}, " - - self.pbar.set_description(msg) - self.pbar.update(1) - - def on_run_end(self, **kwargs): - self.pbar.close() - del self.pbar diff --git a/mart/attack/callbacks/__init__.py b/mart/callbacks/__init__.py similarity index 84% rename from mart/attack/callbacks/__init__.py rename to mart/callbacks/__init__.py index 736f7dd1..7ce8b2cf 100644 --- a/mart/attack/callbacks/__init__.py +++ b/mart/callbacks/__init__.py @@ -1,4 +1,3 @@ -from .base import * from .eval_mode import * from .no_grad_mode import * from .progress_bar import * diff --git a/mart/attack/callbacks/eval_mode.py b/mart/callbacks/eval_mode.py similarity index 78% rename from mart/attack/callbacks/eval_mode.py rename to mart/callbacks/eval_mode.py index de5eef75..be3b6397 100644 --- a/mart/attack/callbacks/eval_mode.py +++ b/mart/callbacks/eval_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from .base import Callback +from pytorch_lightning.callbacks import Callback __all__ = ["AttackInEvalMode"] @@ -15,11 +15,11 @@ class AttackInEvalMode(Callback): def __init__(self): self.training_mode_status = None - def on_run_start(self, *, model, **kwargs): + def on_train_start(self, trainer, model): self.training_mode_status = model.training model.train(False) - def on_run_end(self, *, model, **kwargs): + def on_train_end(self, trainer, model): assert self.training_mode_status is not None # Resume the previous training status of the model. diff --git a/mart/attack/callbacks/no_grad_mode.py b/mart/callbacks/no_grad_mode.py similarity index 77% rename from mart/attack/callbacks/no_grad_mode.py rename to mart/callbacks/no_grad_mode.py index bca4d971..cfb90ead 100644 --- a/mart/attack/callbacks/no_grad_mode.py +++ b/mart/callbacks/no_grad_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from .base import Callback +from pytorch_lightning.callbacks import Callback __all__ = ["ModelParamsNoGrad"] @@ -15,10 +15,10 @@ class ModelParamsNoGrad(Callback): This callback should not change the result. Don't use unless an attack runs faster. """ - def on_run_start(self, *, model, **kwargs): + def on_train_start(self, trainer, model): for param in model.parameters(): param.requires_grad_(False) - def on_run_end(self, *, model, **kwargs): + def on_train_end(self, trainer, model): for param in model.parameters(): param.requires_grad_(True) diff --git a/mart/callbacks/progress_bar.py b/mart/callbacks/progress_bar.py new file mode 100644 index 00000000..f33811d7 --- /dev/null +++ b/mart/callbacks/progress_bar.py @@ -0,0 +1,58 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from typing import Any + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import TQDMProgressBar +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +__all__ = ["ProgressBar"] + + +class ProgressBar(TQDMProgressBar): + """Display progress bar of attack iterations with the gain value.""" + + def __init__(self, *args, disable=False, rename_metrics=None, **kwargs): + if "process_position" not in kwargs: + # Automatically place the progress bar by rank if position is not specified. + # rank starts with 0 + rank_id = rank_zero_only.rank + # Adversary progress bars start at position 1, because the main progress bar takes position 0. 
+ process_position = rank_id + 1 + kwargs["process_position"] = process_position + + super().__init__(*args, **kwargs) + + if disable: + self.disable() + + # E.g. rename loss as gain for adversary's progress bar. + self.rename_metrics = rename_metrics or {} + + def init_train_tqdm(self): + bar = super().init_train_tqdm() + bar.leave = False + bar.set_description("Attack") + bar.unit = "iter" + + return bar + + def on_train_epoch_start(self, trainer: pl.Trainer, *_: Any) -> None: + super().on_train_epoch_start(trainer) + + # So that it does not display negative rate. + self.main_progress_bar.initial = 0 + # So that it does not display Epoch n. + rank_id = rank_zero_only.rank + self.main_progress_bar.set_description(f"Attack@rank{rank_id}") + + def get_metrics(self, *args, **kwargs): + """Rename metrics on progress bar status.""" + metrics = super().get_metrics(*args, **kwargs) + for old_name, new_name in self.rename_metrics.items(): + metrics[new_name] = metrics.pop(old_name) + return metrics diff --git a/mart/attack/callbacks/visualizer.py b/mart/callbacks/visualizer.py similarity index 52% rename from mart/attack/callbacks/visualizer.py rename to mart/callbacks/visualizer.py index d0eb0c58..3354321e 100644 --- a/mart/attack/callbacks/visualizer.py +++ b/mart/callbacks/visualizer.py @@ -6,10 +6,9 @@ import os +from pytorch_lightning.callbacks import Callback from torchvision.transforms import ToPILImage -from .base import Callback - __all__ = ["PerturbedImageVisualizer"] @@ -19,16 +18,23 @@ class PerturbedImageVisualizer(Callback): def __init__(self, folder): super().__init__() + # FIXME: This should use the Trainer's logging directory. self.folder = folder self.convert = ToPILImage() if not os.path.isdir(self.folder): os.makedirs(self.folder) - def on_run_end(self, *, adversary, input, target, model, **kwargs): - adv_input = adversary(input=input, target=target, model=None, **kwargs) + def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): + # Save input and target for on_train_end + self.input = batch["input"] + self.target = batch["target"] + + def on_train_end(self, trainer, model): + # FIXME: We should really just save this to outputs instead of recomputing adv_input + adv_input = model(input=self.input, target=self.target) - for img, tgt in zip(adv_input, target): + for img, tgt in zip(adv_input, self.target): fname = tgt["file_name"] fpath = os.path.join(self.folder, fname) im = self.convert(img / 255) diff --git a/mart/configs/attack/adversary.yaml b/mart/configs/attack/adversary.yaml index 1ec03b4d..40188b5a 100644 --- a/mart/configs/attack/adversary.yaml +++ b/mart/configs/attack/adversary.yaml @@ -1,3 +1,6 @@ +defaults: + - /callbacks@callbacks: [progress_bar] + _target_: mart.attack.Adversary perturber: ??? composer: ??? 
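When no attacker is supplied, Adversary.__init__ builds the Trainer lazily from the partial shown earlier, and the attacker property fills in the device on first use. With the progress_bar callback injected by this config, the resolved Trainer looks roughly like the sketch below (the GPU accelerator and device index are assumptions; 10 iterations is the max_iters default):

import pytorch_lightning as pl

from mart.callbacks import ProgressBar

# The callback dict comes from the callbacks config group; Adversary converts its values to a list.
callbacks = {"progress_bar": ProgressBar(disable=False, rename_metrics={"loss": "gain"})}

attacker = pl.Trainer(
    accelerator="gpu",  # resolved from adversary.device in the attacker property
    devices=[0],  # assumed single-GPU placement
    num_sanity_val_steps=0,
    logger=False,
    max_epochs=0,  # incremented by 1 before every call to fit()
    limit_train_batches=10,  # the max_iters kwarg, i.e. the number of attack iterations
    callbacks=list(callbacks.values()),
    enable_model_summary=False,
    enable_checkpointing=False,
    enable_progress_bar=True,
)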
diff --git a/mart/configs/attack/callbacks/attack_in_eval_mode.yaml b/mart/configs/attack/callbacks/attack_in_eval_mode.yaml deleted file mode 100644 index 15768e22..00000000 --- a/mart/configs/attack/callbacks/attack_in_eval_mode.yaml +++ /dev/null @@ -1,2 +0,0 @@ -attack_in_eval_mode: - _target_: mart.attack.callbacks.AttackInEvalMode diff --git a/mart/configs/attack/callbacks/no_grad_mode.yaml b/mart/configs/attack/callbacks/no_grad_mode.yaml deleted file mode 100644 index c94b9597..00000000 --- a/mart/configs/attack/callbacks/no_grad_mode.yaml +++ /dev/null @@ -1,2 +0,0 @@ -attack_in_eval_mode: - _target_: mart.attack.callbacks.ModelParamsNoGrad diff --git a/mart/configs/attack/callbacks/progress_bar.yaml b/mart/configs/attack/callbacks/progress_bar.yaml deleted file mode 100644 index 21d4c477..00000000 --- a/mart/configs/attack/callbacks/progress_bar.yaml +++ /dev/null @@ -1,2 +0,0 @@ -progress_bar: - _target_: mart.attack.callbacks.ProgressBar diff --git a/mart/configs/attack/classification_eps1.75_fgsm.yaml b/mart/configs/attack/classification_eps1.75_fgsm.yaml index dc3d2cb2..9ce4708c 100644 --- a/mart/configs/attack/classification_eps1.75_fgsm.yaml +++ b/mart/configs/attack/classification_eps1.75_fgsm.yaml @@ -27,3 +27,8 @@ perturber: projector: eps: 1.75 + +# We can turn off progress bar for one-step attack. +callbacks: + progress_bar: + disable: true diff --git a/mart/configs/attack/object_detection_mask_adversary.yaml b/mart/configs/attack/object_detection_mask_adversary.yaml index 19ff5659..ad99dda0 100644 --- a/mart/configs/attack/object_detection_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_mask_adversary.yaml @@ -8,7 +8,6 @@ defaults: - gain: rcnn_training_loss - gradient_modifier: sign - objective: zero_ap - - callbacks: [image_visualizer] - enforcer: default - enforcer/constraints: [mask, pixel_range] diff --git a/mart/configs/callbacks/attack_in_eval_mode.yaml b/mart/configs/callbacks/attack_in_eval_mode.yaml new file mode 100644 index 00000000..2acdc953 --- /dev/null +++ b/mart/configs/callbacks/attack_in_eval_mode.yaml @@ -0,0 +1,2 @@ +attack_in_eval_mode: + _target_: mart.callbacks.AttackInEvalMode diff --git a/mart/configs/callbacks/default.yaml b/mart/configs/callbacks/default.yaml index 5df27bfd..abdfa8b2 100644 --- a/mart/configs/callbacks/default.yaml +++ b/mart/configs/callbacks/default.yaml @@ -1,8 +1,7 @@ defaults: - - model_checkpoint.yaml - - early_stopping.yaml - - model_summary.yaml - - rich_progress_bar.yaml + - model_checkpoint + - model_summary + - rich_progress_bar - _self_ model_checkpoint: @@ -13,10 +12,5 @@ model_checkpoint: save_last: True auto_insert_metric_name: False -early_stopping: - monitor: "val/acc" - patience: 100 - mode: "max" - model_summary: max_depth: -1 diff --git a/mart/configs/attack/callbacks/image_visualizer.yaml b/mart/configs/callbacks/image_visualizer.yaml similarity index 53% rename from mart/configs/attack/callbacks/image_visualizer.yaml rename to mart/configs/callbacks/image_visualizer.yaml index a75b6db2..65b9f8dd 100644 --- a/mart/configs/attack/callbacks/image_visualizer.yaml +++ b/mart/configs/callbacks/image_visualizer.yaml @@ -1,3 +1,3 @@ image_visualizer: - _target_: mart.attack.callbacks.PerturbedImageVisualizer + _target_: mart.callbacks.PerturbedImageVisualizer folder: ${paths.output_dir}/adversarial_examples diff --git a/mart/configs/callbacks/no_grad_mode.yaml b/mart/configs/callbacks/no_grad_mode.yaml new file mode 100644 index 00000000..6b4312fd --- /dev/null +++ 
b/mart/configs/callbacks/no_grad_mode.yaml @@ -0,0 +1,2 @@ +attack_in_eval_mode: + _target_: mart.callbacks.ModelParamsNoGrad diff --git a/mart/configs/callbacks/progress_bar.yaml b/mart/configs/callbacks/progress_bar.yaml new file mode 100644 index 00000000..61be62f2 --- /dev/null +++ b/mart/configs/callbacks/progress_bar.yaml @@ -0,0 +1,6 @@ +progress_bar: + _target_: mart.callbacks.ProgressBar + # Enable progress bar for adversary by default. + disable: false + rename_metrics: + loss: gain diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py index 91c84339..50e71b3d 100644 --- a/mart/utils/__init__.py +++ b/mart/utils/__init__.py @@ -3,4 +3,5 @@ from .monkey_patch import * from .pylogger import * from .rich_utils import * +from .silent import * from .utils import * diff --git a/mart/utils/silent.py b/mart/utils/silent.py new file mode 100644 index 00000000..b9cbd1c3 --- /dev/null +++ b/mart/utils/silent.py @@ -0,0 +1,30 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +import logging +from contextlib import ContextDecorator + +__all__ = ["silent"] + + +class silent(ContextDecorator): + """Suppress logging.""" + + DEFAULT_NAMES = ["pytorch_lightning.utilities.rank_zero", "pytorch_lightning.accelerators.gpu"] + + def __init__(self, names=None): + if names is None: + names = silent.DEFAULT_NAMES + + self.loggers = [logging.getLogger(name) for name in names] + + def __enter__(self): + for logger in self.loggers: + logger.propagate = False + + def __exit__(self, exc_type, exc_value, traceback): + for logger in self.loggers: + logger.propagate = True  # restore propagation disabled in __enter__ diff --git a/tests/test_adversary.py b/tests/test_adversary.py index fe91ba5e..0d0777ae 100644 --- a/tests/test_adversary.py +++ b/tests/test_adversary.py @@ -19,9 +19,10 @@ def test_adversary(input_data, target_data, perturbation): perturber = Mock(spec=Perturber, return_value=perturbation) - composer = Mock(sepc=Composer, return_value=input_data + perturbation) + composer = mart.attack.composer.Additive() gain = Mock() enforcer = Mock() + attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0)) adversary = Adversary( perturber=perturber, @@ -29,54 +30,57 @@ def test_adversary(input_data, target_data, perturbation): optimizer=None, gain=gain, enforcer=enforcer, - max_iters=1, + attacker=attacker, ) output_data = adversary(input=input_data, target=target_data) # The enforcer and attacker should only be called when model is not None.
+ enforcer.assert_not_called() + attacker.fit.assert_not_called() + assert attacker.fit_loop.max_epochs == 0 + perturber.assert_called_once() gain.assert_not_called() - enforcer.assert_not_called() torch.testing.assert_close(output_data, input_data + perturbation) def test_with_model(input_data, target_data, perturbation): perturber = Mock(spec=Perturber, return_value=perturbation) - composer = Mock(sepc=Composer, return_value=input_data + perturbation) + composer = mart.attack.composer.Additive() gain = Mock() enforcer = Mock() - model = Mock(return_value={"loss": 0}) + attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0)) + model = Mock() sequence = Mock() - optimizer = Mock() - optimizer_fn = Mock(spec=mart.optim.OptimizerFactory, return_value=optimizer) adversary = Adversary( perturber=perturber, composer=composer, - optimizer=optimizer_fn, + optimizer=None, gain=gain, enforcer=enforcer, - max_iters=1, + attacker=attacker, ) output_data = adversary(input=input_data, target=target_data, model=model, sequence=sequence) # The enforcer is only called when model is not None. enforcer.assert_called_once() + attacker.fit.assert_called_once() # Once with model=None to get perturbation. # When model=model, configure_perturbation() should be called. perturber.assert_called_once() - assert gain.call_count == 2 # examine is called before done + gain.assert_not_called() # we mock attacker so this shouldn't be called torch.testing.assert_close(output_data, input_data + perturbation) def test_hidden_params(input_data, target_data, perturbation): initializer = Mock() - composer = Mock() + composer = mart.attack.composer.Additive() projector = Mock() perturber = Perturber(initializer=initializer, projector=projector) @@ -109,25 +113,24 @@ def test_hidden_params(input_data, target_data, perturbation): def test_hidden_params_after_forward(input_data, target_data, perturbation): initializer = Mock() - composer = Mock() + composer = mart.attack.composer.Additive() projector = Mock() perturber = Perturber(initializer=initializer, projector=projector) gain = Mock() enforcer = Mock() - model = Mock(return_value={"loss": 0}) + attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0)) + model = Mock() sequence = Mock() - optimizer = Mock() - optimizer_fn = Mock(return_value=optimizer) adversary = Adversary( perturber=perturber, composer=composer, - optimizer=optimizer_fn, + optimizer=None, gain=gain, enforcer=enforcer, - max_iters=1, + attacker=attacker, ) output_data = adversary(input=input_data, target=target_data, model=model, sequence=sequence) @@ -143,21 +146,20 @@ def test_hidden_params_after_forward(input_data, target_data, perturbation): def test_perturbation(input_data, target_data, perturbation): perturber = Mock(spec=Perturber, return_value=perturbation) - composer = Mock(spec=Composer, return_value=perturbation + input_data) + composer = mart.attack.composer.Additive() gain = Mock() enforcer = Mock() - model = Mock(return_value={"loss": 0}) + attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0)) + model = Mock() sequence = Mock() - optimizer = Mock() - optimizer_fn = Mock(spec=mart.optim.OptimizerFactory, return_value=optimizer) adversary = Adversary( perturber=perturber, composer=composer, - optimizer=optimizer_fn, + optimizer=None, gain=gain, enforcer=enforcer, - max_iters=1, + attacker=attacker, ) _ = adversary(input=input_data, target=target_data, model=model, sequence=sequence) @@ -165,9 +167,9 @@ def 
test_perturbation(input_data, target_data, perturbation): # The enforcer is only called when model is not None. enforcer.assert_called_once() + attacker.fit.assert_called_once() # Once with model and sequence and once without - perturber.configure_perturbation.assert_called_once() assert perturber.call_count == 2 torch.testing.assert_close(output_data, input_data + perturbation) @@ -216,3 +218,115 @@ def model(input, target, model=None, **kwargs): perturbation = input_data - input_adv torch.testing.assert_close(perturbation.unique(), torch.Tensor([-1, 0, 1])) + + +def test_configure_optimizers(input_data, target_data): + perturber = Mock() + composer = mart.attack.composer.Additive() + optimizer = Mock(spec=mart.optim.OptimizerFactory) + gain = Mock() + + adversary = Adversary( + perturber=perturber, + composer=composer, + optimizer=optimizer, + gain=gain, + ) + + adversary.configure_optimizers() + + assert optimizer.call_count == 1 + gain.assert_not_called() + + +def test_training_step(input_data, target_data): + perturber = Mock() + composer = mart.attack.composer.Additive() + optimizer = Mock(spec=mart.optim.OptimizerFactory) + gain = Mock(return_value=torch.tensor(1337)) + model = Mock(return_value={}) + + adversary = Adversary( + perturber=perturber, + composer=composer, + optimizer=optimizer, + gain=gain, + ) + + output = adversary.training_step( + {"input": input_data, "target": target_data, "model": model}, 0 + ) + + gain.assert_called_once() + assert output == 1337 + + +def test_training_step_with_many_gain(input_data, target_data): + perturber = Mock() + composer = mart.attack.composer.Additive() + optimizer = Mock(spec=mart.optim.OptimizerFactory) + gain = Mock(return_value=torch.tensor([1234, 5678])) + model = Mock(return_value={}) + + adversary = Adversary( + perturber=perturber, + composer=composer, + optimizer=optimizer, + gain=gain, + ) + + output = adversary.training_step( + {"input": input_data, "target": target_data, "model": model}, 0 + ) + + assert output == 1234 + 5678 + + +def test_training_step_with_objective(input_data, target_data): + perturber = Mock() + composer = mart.attack.composer.Additive() + optimizer = Mock(spec=mart.optim.OptimizerFactory) + gain = Mock(return_value=torch.tensor([1234, 5678])) + model = Mock(return_value={}) + objective = Mock(return_value=torch.tensor([True, False], dtype=torch.bool)) + + adversary = Adversary( + perturber=perturber, + composer=composer, + optimizer=optimizer, + objective=objective, + gain=gain, + ) + + output = adversary.training_step( + {"input": input_data, "target": target_data, "model": model}, 0 + ) + + assert output == 5678 + + objective.assert_called_once() + + +def test_configure_gradient_clipping(): + perturber = Mock() + composer = mart.attack.composer.Additive() + optimizer = Mock( + spec=mart.optim.OptimizerFactory, param_groups=[{"params": Mock()}, {"params": Mock()}] + ) + gradient_modifier = Mock() + gain = Mock() + + adversary = Adversary( + perturber=perturber, + composer=composer, + optimizer=optimizer, + gradient_modifier=gradient_modifier, + gain=gain, + ) + # We need to mock a trainer since LightningModule does some checks + adversary.trainer = Mock(gradient_clip_val=1.0, gradient_clip_algorithm="norm") + + adversary.configure_gradient_clipping(optimizer, 0) + + # Once for each parameter in the optimizer + assert gradient_modifier.call_count == 2 diff --git a/tests/test_visualizer.py b/tests/test_visualizer.py index 5a269db2..5c25e930 100644 --- a/tests/test_visualizer.py +++ 
b/tests/test_visualizer.py @@ -10,7 +10,7 @@ from torchvision.transforms import ToPILImage from mart.attack import Adversary -from mart.attack.callbacks import PerturbedImageVisualizer +from mart.callbacks import PerturbedImageVisualizer def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path): @@ -19,15 +19,19 @@ def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path): target_list = [target_data] # simulate an addition perturbation - def perturb(input, target, model): + def perturb(input): result = [sample + perturbation for sample in input] return result - model = Mock() + trainer = Mock() + model = Mock(return_value=perturb(input_list)) + outputs = Mock() + batch = {"input": input_list, "target": target_list} adversary = Mock(spec=Adversary, side_effect=perturb) visualizer = PerturbedImageVisualizer(folder) - visualizer.on_run_end(adversary=adversary, input=input_list, target=target_list, model=model) + visualizer.on_train_batch_end(trainer, model, outputs, batch, 0) + visualizer.on_train_end(trainer, model) # verify that the visualizer created the JPG file expected_output_path = folder / target_data["file_name"]
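The new mart.utils.silent helper, used as @silent() on Adversary.forward above, keeps the per-batch attack from re-printing PyTorch Lightning's rank-zero and accelerator banners every time the internal Trainer.fit() runs. A minimal usage sketch, assuming only the default logger names listed in silent.py:

import logging

from mart.utils import silent

# As a context manager: records from the named loggers stop propagating to the root
# logger's handlers inside the block, so repeated Trainer banners are not printed.
with silent(names=["pytorch_lightning.utilities.rank_zero"]):
    logging.getLogger("pytorch_lightning.utilities.rank_zero").info("suppressed")

# As a decorator, which is how Adversary.forward uses it (with the default logger names):
@silent()
def run_attack():
    ...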