Commit 86b4ab4

HuanyuZhang authored and facebook-github-bot committed
Add gradient sample mode to the logging system (#735)
Summary:
Pull Request resolved: #735

We add the gradient sample mode of each submodule to the logging system, which is especially useful when checking the compatibility of a complex model architecture.

Reviewed By: iden-kalemaj

Differential Revision: D70255075

fbshipit-source-id: 5defe47bec759c9f66f1071a389f9314c20626bd
1 parent 0a70a1d · commit 86b4ab4

File tree

1 file changed: +49 −1 lines changed

Diff for: opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

+49 −1

@@ -26,7 +26,12 @@
     create_or_accumulate_grad_sample,
     promote_current_grad_sample,
 )
-from opacus.utils.module_utils import requires_grad, trainable_parameters
+from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN
+from opacus.utils.module_utils import (
+    requires_grad,
+    trainable_modules,
+    trainable_parameters,
+)


 logger = logging.getLogger(__name__)

@@ -109,6 +114,12 @@ def __init__(
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
                 submodules) includes a buffer.
         """
+        if logger.isEnabledFor(logging.INFO):
+            self.log_module_gradient_sample_mode(
+                module=m,
+                force_functorch=force_functorch,
+                use_ghost_clipping=use_ghost_clipping,
+            )

         super().__init__(
             m,

@@ -234,6 +245,43 @@ def capture_backprops_hook(
         if hasattr(module, "max_batch_len"):
             del module.max_batch_len

+    def log_module_gradient_sample_mode(
+        self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True
+    ):
+        """
+        Add logs to track gradient sample mode for each part of the module, including 1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), and 3) Fast Gradient Clipping (functorch mode).
+
+        Args:
+            module: nn.Module to be checked
+            force_functorch: If set to ``True``, will use functorch to compute
+                all per sample gradients. Otherwise, functorch will be used only
+                for layers without registered grad sampler methods.
+            use_ghost_clipping: If set to ``True``, Ghost Clipping
+                will be used for clipping gradients of supported layers. If ``False``, Fast
+                Gradient Clipping will be used for all layers.
+        """
+        for m_name, m in trainable_modules(module):
+            if type(m) in [DPRNN, DPLSTM, DPGRU]:
+                logger.info(
+                    f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added."
+                )
+
+            elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS:
+                logger.info(
+                    f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping."
+                )
+
+            else:
+                if not force_functorch and type(m) in self.GRAD_SAMPLERS:
+                    # When functorch is not enforced, use FGC (hook mode) if the layer has a registered grad_sampler (supported). Otherwise, use FGC (functorch mode).
+                    logger.info(
+                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)."
+                    )
+                else:
+                    logger.info(
+                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)."
+                    )
+
     @property
     def per_sample_gradient_norms(self) -> torch.Tensor:
         """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
