From f86ddf4380cdbcb2fd37bb886b8da7fa7d20b613 Mon Sep 17 00:00:00 2001 From: Iden Kalemaj Date: Thu, 19 Dec 2024 11:23:42 -0800 Subject: [PATCH] Separate function for preparing criterion in PrivacyEngine (#703) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/703 Having a separate function for preparing the criterion makes it easy to build custom extensions of PrivacyEnginge for methods that require a different DPLoss class, e.g., adaptive clipping. Reviewed By: EnayatUllah Differential Revision: D67458234 fbshipit-source-id: 9fca64fcde7714708ac1cb9a35a991099606f449 --- opacus/privacy_engine.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index 1af891c4..558c8f8e 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -212,6 +212,26 @@ def _prepare_model( loss_reduction=loss_reduction, ) + def _prepare_criterion( + self, + *, + module: GradSampleModule, + optimizer: DPOptimizer, + criterion=nn.CrossEntropyLoss(), + loss_reduction: str = "mean", + **kwargs, + ) -> DPLossFastGradientClipping: + """ + Args: + module: GradSampleModule used for training, + optimizer: DPOptimizer used for training, + criterion: Loss function used for training, + loss_reduction: "mean" or "sum", indicates if the loss reduction (for aggregating the gradients) + + Prepare the DP loss class, which packages the two backward passes for fast gradient clipping. + """ + return DPLossFastGradientClipping(module, optimizer, criterion, loss_reduction) + def is_compatible( self, *, @@ -403,9 +423,14 @@ def make_private( self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate) ) if grad_sample_mode == "ghost": - criterion = DPLossFastGradientClipping( - module, optimizer, criterion, loss_reduction + criterion = self._prepare_criterion( + module=module, + optimizer=optimizer, + criterion=criterion, + loss_reduction=loss_reduction, + **kwargs, ) + return module, optimizer, criterion, data_loader return module, optimizer, data_loader