Skip to content

Commit 4b0cd91

Browse files
iden-kalemaj authored and facebook-github-bot committed
Add **kwargs to all optimizer classes (#710)
Summary: Pull Request resolved: #710. Purpose: to enable creating custom PrivacyEngines that extend the PrivacyEngine class and take in additional parameters. Fixes prior diff: D67456352. Reviewed By: HuanyuZhang. Differential Revision: D67953655. fbshipit-source-id: 70aef7571e012a370d6a0fd04948eccee06c9a0d
1 parent 3934851 commit 4b0cd91

8 files changed

+13
-0
lines changed

opacus/optimizers/adaclipoptimizer.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
loss_reduction: str = "mean",
5454
generator=None,
5555
secure_mode: bool = False,
56+
**kwargs,
5657
):
5758
super().__init__(
5859
optimizer,

opacus/optimizers/ddp_perlayeroptimizer.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
loss_reduction: str = "mean",
4949
generator=None,
5050
secure_mode: bool = False,
51+
**kwargs,
5152
):
5253
self.rank = torch.distributed.get_rank()
5354
self.world_size = torch.distributed.get_world_size()
@@ -79,6 +80,7 @@ def __init__(
7980
loss_reduction: str = "mean",
8081
generator=None,
8182
secure_mode: bool = False,
83+
**kwargs,
8284
):
8385
self.rank = torch.distributed.get_rank()
8486
self.world_size = torch.distributed.get_world_size()

opacus/optimizers/ddpoptimizer.py

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
loss_reduction: str = "mean",
3939
generator=None,
4040
secure_mode: bool = False,
41+
**kwargs,
4142
):
4243
super().__init__(
4344
optimizer,
@@ -47,6 +48,7 @@ def __init__(
4748
loss_reduction=loss_reduction,
4849
generator=generator,
4950
secure_mode=secure_mode,
51+
**kwargs,
5052
)
5153
self.rank = torch.distributed.get_rank()
5254
self.world_size = torch.distributed.get_world_size()

opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
loss_reduction: str = "mean",
3939
generator=None,
4040
secure_mode: bool = False,
41+
**kwargs,
4142
):
4243
super().__init__(
4344
optimizer,
@@ -47,6 +48,7 @@ def __init__(
4748
loss_reduction=loss_reduction,
4849
generator=generator,
4950
secure_mode=secure_mode,
51+
**kwargs,
5052
)
5153
self.rank = torch.distributed.get_rank()
5254
self.world_size = torch.distributed.get_world_size()

opacus/optimizers/optimizer.py

+1
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def __init__(
205205
loss_reduction: str = "mean",
206206
generator=None,
207207
secure_mode: bool = False,
208+
**kwargs,
208209
):
209210
"""
210211

opacus/optimizers/optimizer_fast_gradient_clipping.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
loss_reduction: str = "mean",
6464
generator=None,
6565
secure_mode: bool = False,
66+
**kwargs,
6667
):
6768
"""
6869
@@ -91,6 +92,7 @@ def __init__(
9192
loss_reduction=loss_reduction,
9293
generator=generator,
9394
secure_mode=secure_mode,
95+
**kwargs,
9496
)
9597

9698
@property

opacus/optimizers/perlayeroptimizer.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
loss_reduction: str = "mean",
4040
generator=None,
4141
secure_mode: bool = False,
42+
**kwargs,
4243
):
4344
assert len(max_grad_norm) == len(params(optimizer))
4445
self.max_grad_norms = max_grad_norm
@@ -51,6 +52,7 @@ def __init__(
5152
loss_reduction=loss_reduction,
5253
generator=generator,
5354
secure_mode=secure_mode,
55+
**kwargs,
5456
)
5557

5658
def clip_and_accumulate(self):

opacus/privacy_engine.py

+1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def _prepare_optimizer(
136136
loss_reduction=loss_reduction,
137137
generator=generator,
138138
secure_mode=self.secure_mode,
139+
**kwargs,
139140
)
140141

141142
def _prepare_data_loader(

0 commit comments

Comments (0)