26 | 26 |     create_or_accumulate_grad_sample,
27 | 27 |     promote_current_grad_sample,
28 | 28 | )
29 | | -from opacus.utils.module_utils import requires_grad, trainable_parameters
| 29 | +from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN
| 30 | +from opacus.utils.module_utils import (
| 31 | +    requires_grad,
| 32 | +    trainable_modules,
| 33 | +    trainable_parameters,
| 34 | +)
30 | 35 |
31 | 36 |
32 | 37 | logger = logging.getLogger(__name__)

@@ -109,6 +114,12 @@ def __init__(
109 | 114 |                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
110 | 115 |                 submodules) includes a buffer.
111 | 116 |         """
| 117 | +        if logger.isEnabledFor(logging.INFO):
| 118 | +            self.log_module_gradient_sample_mode(
| 119 | +                module=m,
| 120 | +                force_functorch=force_functorch,
| 121 | +                use_ghost_clipping=use_ghost_clipping,
| 122 | +            )
112 | 123 |
113 | 124 |         super().__init__(
114 | 125 |             m,
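
Since the new call is gated on logger.isEnabledFor(logging.INFO), the per-layer report is only emitted when INFO logging is enabled for the Opacus loggers. Below is a minimal sketch of one way to turn it on; it assumes the default Python logging setup and relies on the module logger above being created via logging.getLogger(__name__), so it inherits its level from the top-level "opacus" logger.

import logging

# Minimal sketch (assumption: no custom logging configuration is in place).
# The grad sample module's logger is a child of the "opacus" logger, so raising
# that logger to INFO makes the isEnabledFor(logging.INFO) gate above pass.
logging.basicConfig()  # attach a default handler to the root logger
logging.getLogger("opacus").setLevel(logging.INFO)
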
@@ -234,6 +245,43 @@ def capture_backprops_hook(
234 | 245 |             if hasattr(module, "max_batch_len"):
235 | 246 |                 del module.max_batch_len
236 | 247 |
| 248 | +    def log_module_gradient_sample_mode(
| 249 | +        self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True
| 250 | +    ):
| 251 | +        """
| 252 | +        Log the gradient sample mode used for each part of the module: 1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), or 3) Fast Gradient Clipping (functorch mode).
| 253 | +
| 254 | +        Args:
| 255 | +            module: nn.Module to be checked
| 256 | +            force_functorch: If set to ``True``, will use functorch to compute
| 257 | +                all per sample gradients. Otherwise, functorch will be used only
| 258 | +                for layers without registered grad sampler methods.
| 259 | +            use_ghost_clipping: If set to ``True``, Ghost Clipping
| 260 | +                will be used for clipping gradients of supported layers. If ``False``, Fast
| 261 | +                Gradient Clipping will be used for all layers.
| 262 | +        """
| 263 | +        for m_name, m in trainable_modules(module):
| 264 | +            if type(m) in [DPRNN, DPLSTM, DPGRU]:
| 265 | +                logger.info(
| 266 | +                    f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added."
| 267 | +                )
| 268 | +
| 269 | +            elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS:
| 270 | +                logger.info(
| 271 | +                    f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping."
| 272 | +                )
| 273 | +
| 274 | +            else:
| 275 | +                if not force_functorch and type(m) in self.GRAD_SAMPLERS:
| 276 | +                    # When functorch is not enforced, layers with a registered grad sampler use Fast Gradient Clipping (hook mode); other layers fall back to Fast Gradient Clipping (functorch mode).
| 277 | +                    logger.info(
| 278 | +                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)."
| 279 | +                    )
| 280 | +                else:
| 281 | +                    logger.info(
| 282 | +                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)."
| 283 | +                    )
| 284 | +
237 | 285 |     @property
238 | 286 |     def per_sample_gradient_norms(self) -> torch.Tensor:
239 | 287 |         """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
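
To illustrate what the new method reports, here is a hedged usage sketch. The wrapper class name GradSampleModuleFastGradientClipping, its import path, and its default arguments (use_ghost_clipping=True, force_functorch=False) are assumptions based on the surrounding code and may differ across Opacus versions.

import logging

import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping  # assumed import path

# Enable INFO logging so the per-layer report is emitted (see the note above).
logging.basicConfig(level=logging.INFO)

# Toy model: nn.Linear is expected to have a registered norm sampler, so with
# the assumed default use_ghost_clipping=True it should be reported under
# Ghost Clipping; ReLU has no trainable parameters and is skipped.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))

# Constructing the wrapper now logs one INFO line per trainable submodule, e.g.:
#   Module name: 0, module type: <class 'torch.nn.modules.linear.Linear'>, under Ghost Clipping.
gsm = GradSampleModuleFastGradientClipping(model)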