
Commit 20d2078

Update GradientModifier to be an in-place operation (#123)
* Update GradientModifier to be an in-place operation
* Update GradientModifier tests
* Add Adversary gradient test
1 parent dbc6f65 commit 20d2078

4 files changed: +89 -22 lines changed

mart/attack/gradient_modifier.py

Lines changed: 22 additions & 12 deletions

@@ -4,8 +4,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 import abc
-from typing import Union
+from typing import Iterable
 
 import torch
 
@@ -15,25 +17,33 @@
 class GradientModifier(abc.ABC):
     """Gradient modifier base class."""
 
-    @abc.abstractmethod
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
         pass
 
 
 class Sign(GradientModifier):
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
-        return grad.sign()
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        parameters = [p for p in parameters if p.grad is not None]
+
+        for p in parameters:
+            p.grad.detach().sign_()
 
 
 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""
 
-    def __init__(self, p: Union[int, float]):
-        super().__init__
-
+    def __init__(self, p: int | float):
         self.p = p
 
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
-        grad_norm = grad.norm(p=self.p)
-        grad_normalized = grad / grad_norm
-        return grad_normalized
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        parameters = [p for p in parameters if p.grad is not None]
+
+        for p in parameters:
+            p_norm = torch.norm(p.grad.detach(), p=self.p)
+            p.grad.detach().div_(p_norm)
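The signature change is the heart of the commit: a modifier no longer returns a transformed gradient, it mutates `p.grad` in place for a single tensor or an iterable of tensors. A minimal sketch of the new call pattern, mirroring the values used in the updated tests below (assigning `.grad` by hand is only for illustration):

```python
import torch

from mart.attack.gradient_modifier import LpNormalizer, Sign

# Stand-in for a parameter; .grad is assigned manually here, as the tests do.
p = torch.tensor([1.0, 2.0, 3.0])
p.grad = torch.tensor([-1.0, 3.0, 0.0])

Sign()(p)  # in place: p.grad is now tensor([-1., 1., 0.])

p.grad = torch.tensor([-1.0, 3.0, 0.0])
LpNormalizer(1)(p)  # L1 norm of [-1, 3, 0] is 4, so p.grad is now tensor([-0.25, 0.75, 0.])
```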

mart/attack/perturber/perturber.py

Lines changed: 10 additions & 1 deletion

@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from collections import namedtuple
 from typing import Any, Dict, Optional, Union
 
 import torch
@@ -71,7 +72,15 @@ def on_run_start(self, *, adversary, input, target, model, **kwargs):
 
         # A backward hook that will be called when a gradient w.r.t the Tensor is computed.
         if self.gradient_modifier is not None:
-            self.perturbation.register_hook(self.gradient_modifier)
+
+            def gradient_modifier(grad):
+                # Create fake tensor with cloned grad so we can use in-place operations
+                FakeTensor = namedtuple("FakeTensor", ["grad"])
+                param = FakeTensor(grad=grad.clone())
+                self.gradient_modifier([param])
+                return param.grad
+
+            self.perturbation.register_hook(gradient_modifier)
 
         self.initializer(self.perturbation)
 
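`Tensor.register_hook` expects a function that receives the raw gradient and returns its replacement, while the modifiers now operate in place on objects exposing a `.grad` attribute; the `namedtuple` wrapper bridges the two. A self-contained sketch of the same pattern outside `Perturber` (the variable names are illustrative only):

```python
from collections import namedtuple

import torch

from mart.attack.gradient_modifier import Sign

perturbation = torch.zeros(3, requires_grad=True)
modifier = Sign()

def hook(grad):
    # Give the in-place modifier something with a .grad attribute to mutate,
    # then return the mutated clone so autograd uses it as the new gradient.
    FakeTensor = namedtuple("FakeTensor", ["grad"])
    param = FakeTensor(grad=grad.clone())
    modifier([param])
    return param.grad

perturbation.register_hook(hook)

loss = (perturbation * torch.tensor([-2.0, 0.0, 3.0])).sum()
loss.backward()
print(perturbation.grad)  # tensor([-1., 0., 1.])
```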

tests/test_adversary.py

Lines changed: 40 additions & 0 deletions

@@ -12,6 +12,7 @@
 
 import mart
 from mart.attack import Adversary
+from mart.attack.gradient_modifier import Sign
 from mart.attack.perturber import Perturber
 
 
@@ -145,3 +146,42 @@ def model(input, target, model=None, **kwargs):
     # Simulate a new batch of data of different size.
     new_input_data = torch.cat([input_data, input_data])
     output3 = adversary(new_input_data, target_data, model=model)
+
+
+def test_adversary_gradient(input_data, target_data):
+    composer = mart.attack.composer.Additive()
+    enforcer = Mock()
+    optimizer = partial(SGD, lr=1.0, maximize=True)
+
+    # Force zeros, positive and negative gradients
+    def gain(logits):
+        return (
+            (0 * logits[0, :, :]).mean()
+            + (0.1 * logits[1, :, :]).mean()  # noqa: W503
+            + (-0.1 * logits[2, :, :]).mean()  # noqa: W503
+        )
+
+    # Perturbation initialized as zero.
+    def initializer(x):
+        torch.nn.init.constant_(x, 0)
+
+    perturber = Perturber(initializer, Sign())
+
+    adversary = Adversary(
+        composer=composer,
+        enforcer=enforcer,
+        perturber=perturber,
+        optimizer=optimizer,
+        max_iters=1,
+        gain=gain,
+    )
+
+    def model(input, target, model=None, **kwargs):
+        return {"logits": adversary(input, target)}
+
+    adversary(input_data, target_data, model=model)
+    input_adv = adversary(input_data, target_data)
+
+    perturbation = input_data - input_adv
+
+    torch.testing.assert_close(perturbation.unique(), torch.Tensor([-1, 0, 1]))
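The assertion follows from one Sign-modified SGD step: the gain puts a zero, positive, and negative coefficient on the three logit slices, `Sign()` collapses the resulting gradients to {0, +1, -1}, and with `lr=1.0`, `maximize=True`, and a zero-initialized perturbation, the step moves each element by exactly that amount. A simplified sketch of the arithmetic (the shapes are stand-ins, not the test fixture):

```python
import torch
from torch.optim import SGD

from mart.attack.gradient_modifier import Sign

perturbation = torch.zeros(3, requires_grad=True)  # zero-initialized, as in the test
optimizer = SGD([perturbation], lr=1.0, maximize=True)

gain = 0 * perturbation[0] + 0.1 * perturbation[1] - 0.1 * perturbation[2]
gain.backward()       # perturbation.grad == tensor([0.0, 0.1, -0.1])

Sign()(perturbation)  # gradient becomes tensor([0., 1., -1.]) in place
optimizer.step()      # maximize=True ascends by lr * grad: perturbation == tensor([0., 1., -1.])

print(perturbation.detach().unique())  # tensor([-1., 0., 1.])
```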

tests/test_gradient.py

Lines changed: 17 additions & 9 deletions

@@ -11,15 +11,23 @@
 
 
 def test_gradient_sign(input_data):
-    gradient = Sign()
-    output = gradient(input_data)
-    expected_output = input_data.sign()
-    torch.testing.assert_close(output, expected_output)
+    # Don't share input_data with other tests, because the gradient would be changed.
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
 
+    grad_modifier = Sign()
+    grad_modifier(input_data)
+    expected_grad = torch.tensor([-1.0, 1.0, 0.0])
+    torch.testing.assert_close(input_data.grad, expected_grad)
+
+
+def test_gradient_lp_normalizer():
+    # Don't share input_data with other tests, because the gradient would be changed.
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
 
-def test_gradient_lp_normalizer(input_data):
     p = 1
-    gradient = LpNormalizer(p)
-    output = gradient(input_data)
-    expected_output = input_data / input_data.norm(p=p)
-    torch.testing.assert_close(output, expected_output)
+    grad_modifier = LpNormalizer(p)
+    grad_modifier(input_data)
+    expected_grad = torch.tensor([-0.25, 0.75, 0.0])
+    torch.testing.assert_close(input_data.grad, expected_grad)
