Skip to content

Commit

Permalink
Separate function for preparing criterion in PrivacyEngine (#703)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #703

Having a separate function for preparing the criterion makes it easy to build custom extensions of PrivacyEnginge for methods that require a different DPLoss class, e.g., adaptive clipping.

Reviewed By: EnayatUllah

Differential Revision: D67458234

fbshipit-source-id: 9fca64fcde7714708ac1cb9a35a991099606f449
  • Loading branch information
iden-kalemaj authored and facebook-github-bot committed Dec 19, 2024
1 parent 144bd2a commit f86ddf4
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions opacus/privacy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,26 @@ def _prepare_model(
loss_reduction=loss_reduction,
)

def _prepare_criterion(
self,
*,
module: GradSampleModule,
optimizer: DPOptimizer,
criterion=nn.CrossEntropyLoss(),
loss_reduction: str = "mean",
**kwargs,
) -> DPLossFastGradientClipping:
"""
Args:
module: GradSampleModule used for training,
optimizer: DPOptimizer used for training,
criterion: Loss function used for training,
loss_reduction: "mean" or "sum", indicates if the loss reduction (for aggregating the gradients)
Prepare the DP loss class, which packages the two backward passes for fast gradient clipping.
"""
return DPLossFastGradientClipping(module, optimizer, criterion, loss_reduction)

def is_compatible(
self,
*,
Expand Down Expand Up @@ -403,9 +423,14 @@ def make_private(
self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
)
if grad_sample_mode == "ghost":
criterion = DPLossFastGradientClipping(
module, optimizer, criterion, loss_reduction
criterion = self._prepare_criterion(
module=module,
optimizer=optimizer,
criterion=criterion,
loss_reduction=loss_reduction,
**kwargs,
)

return module, optimizer, criterion, data_loader

return module, optimizer, data_loader
Expand Down

0 comments on commit f86ddf4

Please sign in to comment.