Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mart/attack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .adversary import *
from .adversary_in_art import *
from .adversary_wrapper import *
from .attacker_wrapper import *
from .callbacks import Callback
from .composer import *
from .enforcer import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,54 +10,56 @@

import torch

__all__ = ["NormalizedAdversaryAdapter"]
__all__ = ["NormalizedAttackerAdapter"]


class NormalizedAdversaryAdapter(torch.nn.Module):
class NormalizedAttackerAdapter(torch.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine that this is an Adversary and should probably mimic the flow of what the current Adversary class does. But instead of taking a Trainer/Attacker, it takes an attack and instead of calling fit, it just calls the attack (or directly calls run_standard_evaluation).

"""A wrapper for running external classification adversaries in MART.
External adversaries commonly take input of NCWH-[0,1] and return input_adv in the same format.
External attack algorithms commonly take input of NCWH-[0,1] and return input_adv in the same
format.
"""

def __init__(
self,
adversary: Callable[[Callable], Callable],
enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None],
attacker: Callable[[Callable], Callable],
):
"""
Args:
adversary (functools.partial): A partial of an adversary object which awaits model.
enforcer (Callable): Enforcing constraints of an adversary.
attacker (functools.partial): A partial of an attacker object which awaits a model.
"""
super().__init__()

self.adversary = adversary
self.enforcer = enforcer
self.attacker = attacker
self.input_adv = None

def forward(
self,
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
model: torch.nn.Module | None = None,
**kwargs,
):

# Shortcut. Input is already updated in the attack loop.
if model is None:
# Return adversarial input if it is already updated in the attack loop.
if self.input_adv is None:
return input
else:
return self.input_adv

def fit(self, *, input, target, model, **kwargs):
# Input NCHW [0,1]; Output logits.
def model_wrapper(x):
output = model(input=x * 255, target=target, model=None, **kwargs)
logits = output["logits"]
return logits

attack = self.adversary(model_wrapper)
attack = self.attacker(model_wrapper)
input_adv = attack(input / 255, target)

# Round to integer, in case of imprecise scaling.
input_adv = (input_adv * 255).round()
self.enforcer(input_adv, input=input, target=target)

# Save to return later in forward().
self.input_adv = input_adv

return input_adv
42 changes: 23 additions & 19 deletions mart/configs/attack/classification_autoattack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@ defaults:
- enforcer: default
- enforcer/constraints: [lp, pixel_range]

_target_: mart.attack.NormalizedAdversaryAdapter
adversary:
_target_: mart.utils.adapters.PartialInstanceWrapper
partial:
_target_: autoattack.AutoAttack
_partial_: true
# AutoAttack needs to specify device for PyTorch tensors: cpu/cuda
# We can not use ${trainer.accelerator} because the vocabulary is different: cpu/gpu
# device: cpu
norm: Linf
# 8/255
eps: 0.03137254901960784
version: custom
attacks_to_run:
- apgd-dlr
wrapper:
_target_: mart.utils.adapters.CallableAdapter
_partial_: true
redirecting_fn: run_standard_evaluation
_target_: mart.attack.Adversary

enforcer:
constraints:
lp:
eps: 8

attacker:
_target_: mart.attack.NormalizedAttackerAdapter
attacker:
_target_: mart.utils.adapters.PartialInstanceWrapper
partial:
_target_: autoattack.AutoAttack
_partial_: true
# AutoAttack needs to specify device for PyTorch tensors: cpu/cuda
# We can not use ${trainer.accelerator} because the vocabulary is different: cpu/gpu
# device: cpu
norm: Linf
# 8/255
eps: 0.03137254901960784
version: custom
attacks_to_run:
- apgd-dlr
wrapper:
_target_: mart.utils.adapters.CallableAdapter
_partial_: true
redirecting_fn: run_standard_evaluation
2 changes: 1 addition & 1 deletion tests/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_cifar10_cnn_autoattack_experiment(classification_cfg, tmp_path):
"++datamodule.train_dataset.num_classes=10",
"fit=false",
"[email protected]_adv_test=classification_autoattack",
'+model.modules.input_adv_test.adversary.partial.device="cpu"',
'+model.modules.input_adv_test.attacker.attacker.partial.device="cpu"',
"+trainer.limit_test_batches=1",
] + overrides
run_sh_command(command)
Expand Down