|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
9 | 9 | import abc |
10 | | -from typing import Any |
| 10 | +from typing import Any, Iterable |
11 | 11 |
|
12 | 12 | import torch |
13 | 13 |
|
@@ -95,45 +95,36 @@ def verify(self, input_adv, *, input, target): |
95 | 95 |
|
96 | 96 |
|
class Enforcer:
    """Applies a collection of constraints to adversarial inputs.

    Each constraint is a callable invoked as
    ``constraint(input_adv, input=..., target=...)`` and is expected to raise
    if the adversarial input violates it (see ``Constraint.verify``).
    """

    def __init__(self, constraints: dict[str, Constraint]) -> None:
        """Store the constraints to enforce.

        Args:
            constraints: mapping of name -> Constraint. Keys are intentionally
                ignored; only the constraint objects are kept.
        """
        self.constraints = list(constraints.values())  # intentionally ignore keys

    @torch.no_grad()
    def __call__(
        self,
        input_adv: torch.Tensor | Iterable[torch.Tensor],
        *,
        input: torch.Tensor | Iterable[torch.Tensor],
        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
        **kwargs,
    ):
        """Enforce constraints on a single example or a batch of examples.

        A single ``(input_adv, input)`` tensor pair is checked directly;
        iterables are zipped element-wise and each triple is checked in turn.
        Extra keyword arguments are accepted and ignored for caller
        compatibility.

        Raises:
            ValueError: if the inputs are neither tensors nor iterables
                (e.g. a tensor paired with a list).
        """
        if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor):
            # Single example: enforce directly on the tensor pair.
            self.enforce(input_adv, input=input, target=target)

        elif (
            isinstance(input_adv, Iterable)
            and isinstance(input, Iterable)  # noqa: W503
            and isinstance(target, Iterable)  # noqa: W503
        ):
            # Batched examples: enforce per element.
            # NOTE(review): zip() silently truncates on length mismatch —
            # confirm callers always pass equally-sized collections.
            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
                self.enforce(input_adv_i, input=input_i, target=target_i)

        else:
            # Unsupported input types used to raise before the refactor; without
            # this branch a type mismatch becomes a silent no-op, i.e. no
            # constraints are enforced at all. Fail loudly instead.
            raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")

    @torch.no_grad()
    def enforce(
        self,
        input_adv: torch.Tensor,
        *,
        input: torch.Tensor,
        target: torch.Tensor | dict[str, Any],
    ):
        """Run every stored constraint against one adversarial example."""
        for constraint in self.constraints:
            constraint(input_adv, input=input, target=target)
0 commit comments