5 changes: 5 additions & 0 deletions configs/composer/callbacks/log_gradient_variance.yaml
@@ -0,0 +1,5 @@
+log_gradient_variance:
+  _target_: src.custom_composer.callbacks.LogGradientVariance
+  accumulation_steps: 10
+  log_frequency: 1
+  include_embedding_params: false
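With these defaults, the callback buffers gradients from every train batch (`log_frequency: 1`) and logs one variance value for every 10 buffered batches (`accumulation_steps: 10`), skipping embedding parameters. As a minimal sketch of how this config resolves — assuming the project uses Hydra's standard `_target_` instantiation, which the `_target_` key suggests (the surrounding trainer wiring is not shown in this diff):

```python
# Hypothetical instantiation sketch; the config path matches this PR,
# but loading it standalone like this is an assumption.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/composer/callbacks/log_gradient_variance.yaml")
callback = instantiate(cfg.log_gradient_variance)
# Equivalent to: LogGradientVariance(accumulation_steps=10, log_frequency=1,
#                                    include_embedding_params=False)
```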
1 change: 1 addition & 0 deletions configs/composer/default_composer.yaml
@@ -10,6 +10,7 @@ defaults:
- lr_monitor
- speed_monitor
- log_gradient_norms
+ - log_gradient_variance
- hf_compatible_checkpointing
- save_best_checkpointing
- loggers: wandb
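Registering the callback in the `defaults` list of `default_composer.yaml` enables gradient-variance logging for every run built on the default composer config, alongside the existing `log_gradient_norms` callback.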
92 changes: 85 additions & 7 deletions src/custom_composer/callbacks.py
@@ -452,11 +452,7 @@ def batch_end(self, state: State, logger: Logger) -> None:


class LogGradientNorms(Callback):
"""Log gradient norms of non-embedding parameters.

This callback computes and logs the L2 norm of gradients for all parameters
that are not embeddings (i.e., parameters whose name does not contain "embed").
"""
"""Log gradient norms of model parameters."""

    def __init__(
        self,
@@ -503,14 +499,96 @@ def after_backward(self, state: State, logger: Logger) -> None:
        # Log total gradient norm across all tracked parameters
        total_norm = total_norm ** (1.0 / 2)
        if self.include_embedding_params:
-            metrics["grad_norm/total"] = total_norm
+            metrics["grad_stats/norm"] = total_norm
        else:
-            metrics["grad_norm/total_non_embedding"] = total_norm
+            metrics["grad_stats/norm_non_embedding"] = total_norm

        if metrics:
            logger.log_metrics(metrics)


+class LogGradientVariance(Callback):
+    """Log the total variance of model gradients across multiple training steps."""
+
+    def __init__(
+        self,
+        accumulation_steps: int = 10,
+        log_frequency: int = 1,
+        include_embedding_params: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        super().__init__(*args, **kwargs)
+        self.accumulation_steps = accumulation_steps
+        self.log_frequency = log_frequency
+        self.include_embedding_params = include_embedding_params
+        self.step_count = 0
+        self.accumulated_grads: list[dict[str, torch.Tensor]] = []
+
+    @staticmethod
+    def _is_embedding_param(param_name: str) -> bool:
+        """Check if a parameter is an embedding parameter."""
+        return "embed" in param_name.lower()
+
+    @staticmethod
+    def _get_model(state: State):
+        """Get the underlying model, unwrapping `.module` on wrapped models."""
+        if hasattr(state.model, "module"):
+            return state.model.module.model
+        return state.model.model
+
+    def after_train_batch(self, state: State, logger: Logger) -> None:
+        """Accumulate gradients and log their variance across steps."""
+        self.step_count += 1
+        if self.step_count % self.log_frequency != 0:
+            return
+
+        model = self._get_model(state)
+
+        # Collect the current gradients, flattened per parameter
+        current_grads = {}
+        for name, param in model.named_parameters():
+            if param.grad is not None and (
+                self.include_embedding_params or not self._is_embedding_param(name)
+            ):
+                # detach().clone() avoids the legacy `.data`; reshape(-1) also
+                # handles non-contiguous gradients, where view(-1) would fail
+                current_grads[name] = param.grad.detach().clone().reshape(-1)
+
+        # Accumulate gradients
+        self.accumulated_grads.append(current_grads)
+
+        # Once enough steps are buffered, log the variance and reset
+        if len(self.accumulated_grads) >= self.accumulation_steps:
+            metrics = self._compute_variance_across_steps()
+            if metrics:
+                logger.log_metrics(metrics)
+            self.accumulated_grads.clear()
+
+    def _compute_variance_across_steps(self) -> dict[str, float]:
+        """Compute the total variance of gradients across accumulated steps."""
+        if len(self.accumulated_grads) < 2:
+            # Variance over fewer than two samples is undefined (and would
+            # divide by zero below)
+            return {}
+        metrics = {}
+        # Sort names so the concatenation order is identical for every step
+        param_names = sorted(self.accumulated_grads[0].keys())
+        all_gradients = torch.stack(
+            [
+                torch.cat([step_grads[name] for name in param_names], dim=0)
+                for step_grads in self.accumulated_grads
+            ],
+            dim=0,
+        )
+
+        # Unbiased total variance: sum of squared deviations from the mean
+        # gradient over steps, divided by (num_steps - 1)
+        total_variance = torch.norm(
+            all_gradients - all_gradients.mean(dim=0), p=2, dim=1
+        ).pow(2)
+        total_variance = total_variance.sum() / (all_gradients.shape[0] - 1)
+
+        if self.include_embedding_params:
+            metrics["grad_stats/variance"] = total_variance.item()
+        else:
+            metrics["grad_stats/variance_non_embedding"] = total_variance.item()
+        return metrics


class WarmupWithFrozenEncoder(Callback):
    def __init__(
        self,
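The estimator in `_compute_variance_across_steps` is the trace of the empirical gradient covariance: summing the squared per-step deviations from the mean gradient and dividing by (num_steps − 1) equals the sum of the unbiased per-coordinate variances. A self-contained sketch (not part of the PR; shapes and names are illustrative) that checks this equivalence:

```python
# Standalone check of the variance estimator used above (illustrative only).
import torch

steps, dim = 10, 5               # stand-ins for accumulation_steps and
grads = torch.randn(steps, dim)  # the flattened `all_gradients` tensor

centered = grads - grads.mean(dim=0)
total = torch.norm(centered, p=2, dim=1).pow(2).sum() / (steps - 1)

# Equivalent closed form: summed unbiased per-coordinate variance,
# i.e. the trace of the empirical gradient covariance matrix.
reference = grads.var(dim=0, unbiased=True).sum()
assert torch.allclose(total, reference)
```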