Commit 4338e08

Replace tuple with Iterable[torch.Tensor] (#134)
* Replace tuple with Iterable[torch.Tensor]
* Make GradientModifier accept Iterable[torch.Tensor]
* Fix annotations
1 parent 20d2078 commit 4338e08

File tree: 10 files changed, +168 −150 lines


mart/attack/adversary_in_art.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
 
 import hydra
 import numpy
@@ -82,17 +82,18 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
             x (np.ndarray): NHWC, [0, 1]
 
         Returns:
-            tuple: a tuple of tensors in CHW, [0, 255].
+            Iterable[torch.Tensor]: an Iterable of tensors in CHW, [0, 255].
         """
         input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
+        # FIXME: replace tuple with whatever input's type is
         input = tuple(inp_ for inp_ in input)
         return input
 
-    def convert_input_mart_to_art(self, input: tuple):
+    def convert_input_mart_to_art(self, input: Iterable[torch.Tensor]):
         """Convert MART input to the ART's format.
 
         Args:
-            input (tuple): a tuple of tensors in CHW, [0, 255].
+            input (Iterable[torch.Tensor]): an Iterable of tensors in CHW, [0, 255].
 
         Returns:
             np.ndarray: NHWC, [0, 1]
@@ -112,7 +113,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             y_patch_metadata (_type_): _description_
 
         Returns:
-            tuple: a tuple of target dictionaries.
+            Iterable[dict[str, Any]]: an Iterable of target dictionaries.
         """
         # Copy y to target, and convert ndarray to pytorch tensors accordingly.
         target = []
@@ -132,6 +133,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
             target.append(target_i)
 
+        # FIXME: replace tuple with input type?
         target = tuple(target)
 
         return target
```
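
To make the shape contract concrete, here is a self-contained sketch of the conversion the updated docstring describes. The array sizes are illustrative, and the `.to(self._device)` call is dropped since there is no class instance here.

```python
import numpy
import torch

# A hypothetical ART-style batch: NHWC, float32, values in [0, 1].
x = numpy.random.rand(2, 32, 32, 3).astype(numpy.float32)

# Same steps as convert_input_art_to_mart: NHWC -> NCHW, rescale to
# [0, 255], then split the batch into one CHW tensor per example.
input = torch.tensor(x).permute((0, 3, 1, 2)) * 255
input = tuple(inp_ for inp_ in input)  # a tuple for now, per the FIXME

assert len(input) == 2
assert input[0].shape == (3, 32, 32)  # CHW
```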

mart/attack/adversary_wrapper.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -6,10 +6,13 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Iterable
 
 import torch
 
+if TYPE_CHECKING:
+    from .enforcer import Enforcer
+
 __all__ = ["NormalizedAdversaryAdapter"]
 
 
@@ -22,7 +25,7 @@ class NormalizedAdversaryAdapter(torch.nn.Module):
     def __init__(
         self,
         adversary: Callable[[Callable], Callable],
-        enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None],
+        enforcer: Enforcer,
     ):
         """
 
@@ -37,8 +40,8 @@ def __init__(
 
     def forward(
         self,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module | None = None,
         **kwargs,
     ):
```
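
The new `TYPE_CHECKING` guard is the standard way to use `Enforcer` purely as an annotation without importing it at runtime (for example, to avoid an import cycle). A minimal sketch of the same pattern, with a hypothetical `run_attack` function:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from .enforcer import Enforcer


def run_attack(enforcer: Enforcer) -> None:
    # With `from __future__ import annotations`, the annotation above is
    # stored as the string "Enforcer", so no runtime import is needed.
    ...
```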

mart/attack/callbacks/base.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Iterable
 
 import torch
 
@@ -24,8 +24,8 @@ def on_run_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -35,8 +35,8 @@ def on_examine_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -46,8 +46,8 @@ def on_examine_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -57,8 +57,8 @@ def on_advance_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -68,8 +68,8 @@ def on_advance_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -79,8 +79,8 @@ def on_run_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
```
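
For illustration, a hypothetical callback written against the widened hooks. The base-class name is not visible in this diff, so the sketch is duck-typed; only the `on_run_start` signature is taken from base.py.

```python
import torch


class CountingCallback:
    """Hypothetical callback that reports how many inputs the attack received."""

    def on_run_start(self, *, adversary, input, target, model, **kwargs):
        # A single tensor is one batched input; an Iterable carries one
        # tensor per example, so materialize it to count.
        n = 1 if isinstance(input, torch.Tensor) else len(list(input))
        print(f"Attack starting on {n} input(s).")
```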

mart/attack/composer.py

Lines changed: 17 additions & 10 deletions
```diff
@@ -7,29 +7,36 @@
 from __future__ import annotations
 
 import abc
-from typing import Any
+from typing import Any, Iterable
 
 import torch
 
 
 class Composer(abc.ABC):
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         **kwargs,
-    ) -> torch.Tensor | tuple:
-        if isinstance(perturbation, tuple):
-            input_adv = tuple(
+    ) -> torch.Tensor | Iterable[torch.Tensor]:
+        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
+            return self.compose(perturbation, input=input, target=target)
+
+        elif (
+            isinstance(perturbation, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            # FIXME: replace tuple with whatever input's type is
+            return tuple(
                 self.compose(perturbation_i, input=input_i, target=target_i)
                 for perturbation_i, input_i, target_i in zip(perturbation, input, target)
             )
-        else:
-            input_adv = self.compose(perturbation, input=input, target=target)
 
-        return input_adv
+        else:
+            raise NotImplementedError
 
     @abc.abstractmethod
     def compose(
```
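
A sketch of the new dispatch in action; `Additive` is a toy concrete `Composer` defined here for illustration. Note the tensor branch has to be tested first: `torch.Tensor` is itself iterable, so the `Iterable` check alone would not distinguish the two cases.

```python
import torch

from mart.attack.composer import Composer


class Additive(Composer):
    """Toy composer: the adversarial input is input plus perturbation."""

    def compose(self, perturbation, *, input, target):
        return input + perturbation


composer = Additive()

# Tensor perturbation and tensor input: a single composed tensor.
single = composer(torch.ones(3), input=torch.zeros(3), target=None)

# Iterables are zipped and composed element-wise; the result is a tuple
# for now, per the FIXME above.
batch = composer(
    (torch.ones(3), torch.ones(3)),
    input=(torch.zeros(3), torch.zeros(3)),
    target=(None, None),
)
assert isinstance(batch, tuple) and len(batch) == 2
```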

mart/attack/enforcer.py

Lines changed: 24 additions & 33 deletions
```diff
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 import abc
-from typing import Any
+from typing import Any, Iterable
 
 import torch
 
@@ -95,45 +95,36 @@ def verify(self, input_adv, *, input, target):
 
 
 class Enforcer:
-    def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
-        self.modality_constraints = modality_constraints
+    def __init__(self, constraints: dict[str, Constraint]) -> None:
+        self.constraints = list(constraints.values())  # intentionally ignore keys
 
     @torch.no_grad()
-    def _enforce(
+    def __call__(
         self,
-        input_adv: torch.Tensor,
+        input_adv: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor,
-        target: torch.Tensor | dict[str, Any],
-        modality: str,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+        **kwargs,
     ):
-        for constraint in self.modality_constraints[modality].values():
-            constraint(input_adv, input=input, target=target)
+        if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor):
+            self.enforce(input_adv, input=input, target=target)
+
+        elif (
+            isinstance(input_adv, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
+                self.enforce(input_adv_i, input=input_i, target=target_i)
 
-    def __call__(
+    @torch.no_grad()
+    def enforce(
         self,
-        input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
+        input_adv: torch.Tensor,
         *,
-        input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
+        input: torch.Tensor,
         target: torch.Tensor | dict[str, Any],
-        modality: str = "constraints",
-        **kwargs,
     ):
-        assert type(input_adv) == type(input)
-
-        if isinstance(input_adv, torch.Tensor):
-            # Finally we can verify constraints on tensor, per its modality.
-            # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
-            self._enforce(input_adv, input=input, target=target, modality=modality)
-        elif isinstance(input_adv, dict):
-            # The dict input has modalities specified in keys, passing them recursively.
-            for modality in input_adv:
-                self(input_adv[modality], input=input[modality], target=target, modality=modality)
-        elif isinstance(input_adv, (list, tuple)):
-            # We assume a modality-dictionary only contains tensors, but not list/tuple.
-            assert modality == "constraints"
-            # The list or tuple input is a collection of sub-input and sub-target.
-            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
-                self(input_adv_i, input=input_i, target=target_i, modality=modality)
-        else:
-            raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
+        for constraint in self.constraints:
+            constraint(input_adv, input=input, target=target)
```
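
A hedged sketch of driving the reworked class; `PixelRange` is a hypothetical stand-in for a real `Constraint` subclass, since the new `enforce` loop only requires each entry to be callable as `constraint(input_adv, input=..., target=...)`.

```python
import torch

from mart.attack.enforcer import Enforcer


class PixelRange:
    """Hypothetical constraint: adversarial pixels must stay in [0, 255]."""

    def __call__(self, input_adv, *, input, target):
        if input_adv.min() < 0 or input_adv.max() > 255:
            raise ValueError("input_adv escaped the [0, 255] range.")


# Keys are labels only; the new Enforcer keeps just the values.
enforcer = Enforcer({"pixel_range": PixelRange()})

x = torch.zeros(3, 8, 8)

# Single-tensor path: verified directly.
enforcer(x + 10, input=x, target=None)

# Iterable path: zipped and verified per example.
enforcer((x + 10, x + 20), input=(x, x), target=(None, None))
```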

mart/attack/gradient_modifier.py

Lines changed: 16 additions & 20 deletions
```diff
@@ -6,44 +6,40 @@
 
 from __future__ import annotations
 
-import abc
 from typing import Iterable
 
 import torch
 
 __all__ = ["GradientModifier"]
 
 
-class GradientModifier(abc.ABC):
+class GradientModifier:
     """Gradient modifier base class."""
 
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        pass
-
-
-class Sign(GradientModifier):
     def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]
 
-        parameters = [p for p in parameters if p.grad is not None]
+        [self.modify_(parameter) for parameter in parameters]
+
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        pass
+
 
-        for p in parameters:
-            p.grad.detach().sign_()
+class Sign(GradientModifier):
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        parameter.grad.sign_()
 
 
 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""
 
     def __init__(self, p: int | float):
-        self.p = p
-
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-
-        parameters = [p for p in parameters if p.grad is not None]
+        self.p = float(p)
 
-        for p in parameters:
-            p_norm = torch.norm(p.grad.detach(), p=self.p)
-            p.grad.detach().div_(p_norm)
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        p_norm = torch.norm(parameter.grad.detach(), p=self.p)
+        parameter.grad.detach().div_(p_norm)
```
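
A quick sketch of the refactored API on a toy loss: subclasses now implement the per-tensor `modify_`, and the base `__call__` fans out over an `Iterable`. Note the old `p.grad is not None` filter was dropped, so this assumes every tensor passed in already has a gradient.

```python
import torch

from mart.attack.gradient_modifier import LpNormalizer, Sign

x = torch.zeros(3, requires_grad=True)
loss = ((x + torch.tensor([0.5, -2.0, 3.0])) ** 2).sum()
loss.backward()  # x.grad is now [1.0, -4.0, 6.0]

Sign()(x)  # in place; an Iterable[torch.Tensor] works too
print(x.grad)  # tensor([ 1., -1.,  1.])

LpNormalizer(p=2)(x)  # divides the gradient by its L2 norm, sqrt(3) here
print(x.grad)
```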
