12 changes: 7 additions & 5 deletions mart/attack/adversary_in_art.py
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #

-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional

 import hydra
 import numpy
@@ -82,17 +82,18 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
             x (np.ndarray): NHWC, [0, 1]

         Returns:
-            tuple: a tuple of tensors in CHW, [0, 255].
+            Iterable[torch.Tensor]: an Iterable of tensors in CHW, [0, 255].
         """
         input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
+        # FIXME: replace tuple with whatever input's type is
         input = tuple(inp_ for inp_ in input)
         return input

-    def convert_input_mart_to_art(self, input: tuple):
+    def convert_input_mart_to_art(self, input: Iterable[torch.Tensor]):
         """Convert MART input to the ART's format.

         Args:
-            input (tuple): a tuple of tensors in CHW, [0, 255].
+            input (Iterable[torch.Tensor]): an Iterable of tensors in CHW, [0, 255].

         Returns:
             np.ndarray: NHWC, [0, 1]
@@ -112,7 +113,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             y_patch_metadata (_type_): _description_

         Returns:
-            tuple: a tuple of target dictionaies.
+            Iterable[dict[str, Any]]: an Iterable of target dictionaries.
         """
         # Copy y to target, and convert ndarray to pytorch tensors accordingly.
         target = []
@@ -132,6 +133,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
             target.append(target_i)

+        # FIXME: replace tuple with input type?
         target = tuple(target)

         return target
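Reviewer note: the two input converters are meant to be inverses up to rounding. A minimal round-trip sketch, with the device transfer omitted; the forward direction mirrors convert_input_art_to_mart above, while the reverse direction is an assumption, since convert_input_mart_to_art's body sits outside this hunk:

import numpy as np
import torch

# ART side: NHWC batch in [0, 1].
x_art = np.random.rand(2, 32, 32, 3).astype(np.float32)

# ART -> MART: an Iterable of CHW tensors in [0, 255], one per image.
x_mart = tuple(torch.tensor(x_art).permute((0, 3, 1, 2)) * 255)

# MART -> ART (assumed inverse): stack, back to NHWC, rescale to [0, 1].
x_art_again = torch.stack(x_mart).permute((0, 2, 3, 1)).numpy() / 255

assert np.allclose(x_art, x_art_again)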
11 changes: 7 additions & 4 deletions mart/attack/adversary_wrapper.py
@@ -6,10 +6,13 @@

 from __future__ import annotations

-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Iterable

 import torch

+if TYPE_CHECKING:
+    from .enforcer import Enforcer
+
 __all__ = ["NormalizedAdversaryAdapter"]


@@ -22,7 +25,7 @@ class NormalizedAdversaryAdapter(torch.nn.Module):
     def __init__(
         self,
         adversary: Callable[[Callable], Callable],
-        enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None],
+        enforcer: Enforcer,
     ):
         """

@@ -37,8 +40,8 @@ def __init__(

     def forward(
         self,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module | None = None,
         **kwargs,
     ):
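Reviewer note: the new TYPE_CHECKING guard is what lets the enforcer parameter be annotated as Enforcer without importing mart.attack.enforcer at runtime, which avoids a potential circular import. A generic sketch of the pattern, not MART-specific:

from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from .enforcer import Enforcer


def check(enforcer: Enforcer) -> None:
    # Fine at runtime: the annotation above is never evaluated.
    ...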
26 changes: 13 additions & 13 deletions mart/attack/callbacks/base.py
@@ -7,7 +7,7 @@
 from __future__ import annotations

 import abc
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Iterable

 import torch

@@ -24,8 +24,8 @@ def on_run_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -35,8 +35,8 @@ def on_examine_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -46,8 +46,8 @@ def on_examine_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -57,8 +57,8 @@ def on_advance_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -68,8 +68,8 @@ def on_advance_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -79,8 +79,8 @@ def on_run_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
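Reviewer note: since every hook now takes the same widened types, a subclass only has to accept either shape of input. A sketch of a custom callback under the new signatures; the base-class name Callback and the import path are assumptions based on this file, and the body is illustrative:

import torch

from mart.attack.callbacks.base import Callback  # path assumed from this diff


class CountExamples(Callback):
    def on_run_start(self, *, adversary, input, target, model, **kwargs):
        # input may be one batched tensor or an Iterable of per-example tensors.
        if isinstance(input, torch.Tensor):
            print(f"batched input: {tuple(input.shape)}")
        else:
            print(f"{len(list(input))} per-example tensors")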
27 changes: 17 additions & 10 deletions mart/attack/composer.py
@@ -7,29 +7,36 @@
 from __future__ import annotations

 import abc
-from typing import Any
+from typing import Any, Iterable

 import torch


 class Composer(abc.ABC):
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         **kwargs,
-    ) -> torch.Tensor | tuple:
-        if isinstance(perturbation, tuple):
-            input_adv = tuple(
+    ) -> torch.Tensor | Iterable[torch.Tensor]:
+        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
+            return self.compose(perturbation, input=input, target=target)
+
+        elif (
+            isinstance(perturbation, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            # FIXME: replace tuple with whatever input's type is
+            return tuple(
                 self.compose(perturbation_i, input=input_i, target=target_i)
                 for perturbation_i, input_i, target_i in zip(perturbation, input, target)
             )
-        else:
-            input_adv = self.compose(perturbation, input=input, target=target)
-
-        return input_adv
+        else:
+            raise NotImplementedError

     @abc.abstractmethod
     def compose(
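Reviewer note: to exercise both branches of the new __call__, here is a minimal additive composer; the Additive subclass is illustrative, and only the compose hook and the tensor/Iterable dispatch come from this file. The order of the checks matters: a torch.Tensor is itself iterable, so the Iterable branch would swallow tensors if it came first.

import torch

from mart.attack.composer import Composer  # module path from this diff


class Additive(Composer):  # hypothetical subclass for illustration
    def compose(self, perturbation, *, input, target):
        return input + perturbation


composer = Additive()

# Tensor branch: one batched input.
adv = composer(torch.ones(3, 8, 8), input=torch.zeros(3, 8, 8), target=None)

# Iterable branch: per-example tensors of different sizes, zipped with targets.
advs = composer(
    (torch.ones(3, 8, 8), torch.ones(3, 4, 4)),
    input=(torch.zeros(3, 8, 8), torch.zeros(3, 4, 4)),
    target=(None, None),
)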
57 changes: 24 additions & 33 deletions mart/attack/enforcer.py
@@ -7,7 +7,7 @@
 from __future__ import annotations

 import abc
-from typing import Any
+from typing import Any, Iterable

 import torch

@@ -95,45 +95,36 @@ def verify(self, input_adv, *, input, target):


 class Enforcer:
-    def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
-        self.modality_constraints = modality_constraints
+    def __init__(self, constraints: dict[str, Constraint]) -> None:
+        self.constraints = list(constraints.values())  # intentionally ignore keys

-    @torch.no_grad()
-    def _enforce(
+    def __call__(
         self,
-        input_adv: torch.Tensor,
+        input_adv: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor,
-        target: torch.Tensor | dict[str, Any],
-        modality: str,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
+        **kwargs,
     ):
-        for constraint in self.modality_constraints[modality].values():
-            constraint(input_adv, input=input, target=target)
+        if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor):
+            self.enforce(input_adv, input=input, target=target)

+        elif (
+            isinstance(input_adv, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
+                self.enforce(input_adv_i, input=input_i, target=target_i)

-    def __call__(
+    @torch.no_grad()
+    def enforce(
         self,
-        input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
+        input_adv: torch.Tensor,
         *,
-        input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
+        input: torch.Tensor,
         target: torch.Tensor | dict[str, Any],
-        modality: str = "constraints",
         **kwargs,
     ):
-        assert type(input_adv) == type(input)
-
-        if isinstance(input_adv, torch.Tensor):
-            # Finally we can verify constraints on tensor, per its modality.
-            # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
-            self._enforce(input_adv, input=input, target=target, modality=modality)
-        elif isinstance(input_adv, dict):
-            # The dict input has modalities specified in keys, passing them recursively.
-            for modality in input_adv:
-                self(input_adv[modality], input=input[modality], target=target, modality=modality)
-        elif isinstance(input_adv, (list, tuple)):
-            # We assume a modality-dictionary only contains tensors, but not list/tuple.
-            assert modality == "constraints"
-            # The list or tuple input is a collection of sub-input and sub-target.
-            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
-                self(input_adv_i, input=input_i, target=target_i, modality=modality)
-        else:
-            raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
+        for constraint in self.constraints:
+            constraint(input_adv, input=input, target=target)
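Reviewer note: with modalities gone, an Enforcer is built from one flat name-to-constraint dict and applies every value, ignoring the names. A usage sketch; NonNegative is a made-up constraint, and it assumes Constraint subclasses implement verify and are invoked via __call__, as the enforce loop above suggests:

import torch

from mart.attack.enforcer import Constraint, Enforcer


class NonNegative(Constraint):  # hypothetical constraint for illustration
    def verify(self, input_adv, *, input, target):
        if (input_adv < 0).any():
            raise ValueError("adversarial input contains negative values")


enforcer = Enforcer(constraints={"non_negative": NonNegative()})

input = torch.rand(3, 8, 8)
input_adv = input + 0.1

enforcer(input_adv, input=input, target=None)        # tensor branch
enforcer([input_adv], input=[input], target=[None])  # Iterable branch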
36 changes: 16 additions & 20 deletions mart/attack/gradient_modifier.py
@@ -6,44 +6,40 @@

 from __future__ import annotations

-import abc
 from typing import Iterable

 import torch

 __all__ = ["GradientModifier"]


-class GradientModifier(abc.ABC):
+class GradientModifier:
     """Gradient modifier base class."""

     def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        pass
-
-
-class Sign(GradientModifier):
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]

         parameters = [p for p in parameters if p.grad is not None]

-        for p in parameters:
-            p.grad.detach().sign_()
+        [self.modify_(parameter) for parameter in parameters]
+
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        pass
+
+
+class Sign(GradientModifier):
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        parameter.grad.sign_()


 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""

     def __init__(self, p: int | float):
-        self.p = p
-
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-
-        parameters = [p for p in parameters if p.grad is not None]
+        self.p = float(p)

-        for p in parameters:
-            p_norm = torch.norm(p.grad.detach(), p=self.p)
-            p.grad.detach().div_(p_norm)
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        p_norm = torch.norm(parameter.grad.detach(), p=self.p)
+        parameter.grad.detach().div_(p_norm)
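Reviewer note: after this refactor the base __call__ normalizes a lone tensor to a list, drops parameters without gradients, and delegates to the per-tensor modify_ hook, so Sign and LpNormalizer shrink to one method each. A quick usage sketch:

import torch

from mart.attack.gradient_modifier import LpNormalizer, Sign

delta = torch.zeros(3, requires_grad=True)
loss = (delta - torch.tensor([0.5, -1.0, 2.0])).pow(2).sum()
loss.backward()

Sign()(delta)  # in place: delta.grad becomes [-1., 1., -1.]
# LpNormalizer(2)(delta) would instead rescale delta.grad to unit L2 norm.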