diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 7df08f40cf..a0e095fac2 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -389,9 +389,7 @@ def run_validation(step: int) -> None: nan_loss_count = nan_loss_count.item() logger.debug(f"Clipping gradients with max norm {config.optim.max_norm}") - grad_norm = clip_grad_norm_( - model.parameters(), max_norm=config.optim.max_norm, ep_enabled=parallel_dims.ep_enabled - ) + grad_norm = clip_grad_norm_(model.parameters(), max_norm=config.optim.max_norm) if grad_norm.device.type == "cpu": grad_norm = grad_norm.to(torch.device("cuda")) zero_grad_ratio = get_zero_gradient_ratio(model.parameters(), parallel_dims.dp_replicate)