diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py
index 19b5ffa6..12111531 100644
--- a/opacus/grad_sample/grad_sample_module.py
+++ b/opacus/grad_sample/grad_sample_module.py
@@ -145,6 +145,20 @@ def __init__(
             force_functorch=force_functorch,
         )
 
+    def requires_grad_(self, requires_grad: bool = True) -> nn.Module:
+        "Override requires_grad_ to add/remove hooks based on the requires_grad value"
+        if requires_grad:
+            # Attach hooks to the module
+            self.add_hooks(
+                loss_reduction=self.loss_reduction,
+                batch_first=self.batch_first,
+                force_functorch=self.force_functorch,
+            )
+        else:
+            # Remove hooks
+            self.remove_hooks()
+        return super().requires_grad_(requires_grad)
+
     def forward(self, *args, **kwargs):
         return self._module(*args, **kwargs)
 
diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
index 5a9adbb9..3214843f 100644
--- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
+++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -117,7 +117,7 @@ def __init__(
             strict=strict,
             force_functorch=force_functorch,
         )
-        self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
+        self.all_parameters = [p for p in self.parameters()]
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
         self._per_sample_gradient_norms = None
@@ -130,7 +130,12 @@ def get_clipping_coef(self) -> torch.Tensor:
     def get_norm_sample(self) -> torch.Tensor:
         """Get per-example gradient norms."""
         norm_sample = torch.stack(
-            [param._norm_sample for param in self.trainable_parameters], dim=0
+            [
+                param._norm_sample
+                for param in self.all_parameters
+                if param.requires_grad
+            ],
+            dim=0,
        ).norm(2, dim=0)
         self.per_sample_gradient_norms = norm_sample
         return norm_sample
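
For reference, a minimal usage sketch (not part of the patch) of how the overridden requires_grad_ is expected to behave once the change above lands: freezing the wrapped module detaches the per-sample-gradient hooks, and unfreezing re-attaches them. The toy nn.Linear model, tensor shapes, and the freeze/unfreeze sequence are illustrative assumptions, not code from the PR.

import torch
import torch.nn as nn

from opacus.grad_sample import GradSampleModule

# Wrapping a plain module attaches hooks that record per-sample
# gradients into p.grad_sample during backward().
model = GradSampleModule(nn.Linear(4, 2))

# With the patch, freezing the module also removes the hooks, so a
# frozen module no longer pays the per-sample-gradient bookkeeping cost.
model.requires_grad_(False)

# Re-enabling gradients re-attaches the hooks using the module's stored
# loss_reduction / batch_first / force_functorch settings.
model.requires_grad_(True)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()

# Each trainable parameter now carries a per-sample gradient of shape
# (batch_size, *param.shape).
for p in model.parameters():
    if p.requires_grad:
        print(p.grad_sample.shape)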