Commit 0d186a4

EnayatUllah authored and facebook-github-bot committed
Add per-sample gradient norm computation as a functionality (#724)
Summary:
Pull Request resolved: #724

Per-sample gradient norms are computed for Ghost Clipping, but they can be useful more generally. This change exposes them as a property:

```
...
loss.backward()
per_sample_norms = model.per_sample_gradient_norms
```

Reviewed By: iden-kalemaj

Differential Revision: D68634969

fbshipit-source-id: 7d5cb8a05de11d7492d3c1ae7f7384243cc03c73
1 parent e4eb3fb commit 0d186a4
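
For context, the snippet in the summary can be fleshed out roughly as follows. This is a minimal sketch, not part of the commit: the `PrivacyEngine.make_private` call with `criterion=` and `grad_sample_mode="ghost"` (which routes through `GradSampleModuleFastGradientClipping`), the toy model, and the hyperparameters are all assumptions and may differ across Opacus versions.

```python
# Hedged sketch: reading the new per_sample_gradient_norms property in a
# ghost-clipping training loop. Entry-point arguments are assumptions.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8
)

privacy_engine = PrivacyEngine()
# Assumption: grad_sample_mode="ghost" wraps the model in
# GradSampleModuleFastGradientClipping and returns a wrapped criterion.
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    criterion=criterion,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for data, target in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    # New in this commit: raw, non-privatized per-sample gradient norms,
    # meant for debugging or non-private inspection only.
    per_sample_norms = model.per_sample_gradient_norms
    print(per_sample_norms.shape)  # one norm per example in the batch
    optimizer.step()
    break
```

With ghost clipping, the per-sample norms are computed inside `loss.backward()` (via `get_norm_sample()`), so the property is populated by the time `backward()` returns.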

File tree

1 file changed, +16 -0 lines changed


Diff for: opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

@@ -120,6 +120,7 @@ def __init__(
         self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
+        self._per_sample_gradient_norms = None
 
     def get_clipping_coef(self) -> torch.Tensor:
         """Get per-example gradient scaling factor for clipping."""
@@ -131,6 +132,7 @@ def get_norm_sample(self) -> torch.Tensor:
         norm_sample = torch.stack(
             [param._norm_sample for param in self.trainable_parameters], dim=0
         ).norm(2, dim=0)
+        self.per_sample_gradient_norms = norm_sample
         return norm_sample
 
     def capture_activations_hook(
@@ -231,3 +233,17 @@ def capture_backprops_hook(
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len
+
+    @property
+    def per_sample_gradient_norms(self) -> torch.Tensor:
+        """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
+        if self._per_sample_gradient_norms is not None:
+            return self._per_sample_gradient_norms
+        else:
+            raise AttributeError(
+                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
+            )
+
+    @per_sample_gradient_norms.setter
+    def per_sample_gradient_norms(self, value):
+        self._per_sample_gradient_norms = value
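
For a lower-level view, the property can also be exercised by wrapping a module directly in GradSampleModuleFastGradientClipping, the class this diff modifies. This is a sketch under assumptions: the constructor keywords beyond max_grad_norm and use_ghost_clipping, and the defaults for batch layout and loss reduction, are assumed and may vary between Opacus versions.

```python
# Hedged sketch: direct use of GradSampleModuleFastGradientClipping to
# inspect per-sample gradient norms on a toy linear model.
import torch
import torch.nn as nn
from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)

wrapped = GradSampleModuleFastGradientClipping(
    nn.Linear(16, 2),
    max_grad_norm=1.0,
    use_ghost_clipping=True,
)

x = torch.randn(8, 16)          # batch of 8 examples
y = torch.randint(0, 2, (8,))

# Before any forward/backward pass the property raises AttributeError,
# as implemented in the getter above.
try:
    _ = wrapped.per_sample_gradient_norms
except AttributeError:
    print("norms are only available after forward and backward")

loss = nn.functional.cross_entropy(wrapped(x), y)
loss.backward()

# get_norm_sample() stacks the per-parameter norm samples computed by the
# backward hooks and, with this commit, also caches them on the property.
norms = wrapped.get_norm_sample()
assert torch.equal(norms, wrapped.per_sample_gradient_norms)
print(norms.shape)  # torch.Size([8])
```

As the new docstring notes, these norms are not privatized; they are the raw clipping norms and should only be used for debugging or in non-private settings.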
