diff --git a/src/train.py b/src/train.py index 5c3132e..73c1cbe 100644 --- a/src/train.py +++ b/src/train.py @@ -375,9 +375,8 @@ def guidance_value(key: str, default: float) -> float: sampler.set_epoch(epoch) epoch_metrics: Dict[str, torch.Tensor] = defaultdict(lambda: torch.zeros(1, device=device)) num_batches = 0 - optimizer.zero_grad() accum_counter = 0 - step_loss_accum = 0.0 + optimizer.zero_grad(set_to_none=True) if checkpoint_interval > 0 and epoch % checkpoint_interval == 0 and rank == 0: logger.info(f"Saving checkpoint at epoch {epoch}...") ckpt_path = f"{checkpoint_dir}/ep-{epoch:07d}.pt" @@ -395,20 +394,19 @@ def guidance_value(key: str, default: float) -> float: labels = labels.to(device) with torch.no_grad(): # TODO: wrap this in autocast? z = rae.encode(images) - optimizer.zero_grad(set_to_none=True) model_kwargs = dict(y=labels) with autocast(**autocast_kwargs): loss = transport.training_losses(ddp_model, z, model_kwargs)["loss"].mean() - loss.float() if scaler: scaler.scale(loss / grad_accum_steps).backward() else: (loss / grad_accum_steps).backward() - if clip_grad: - if scaler: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), clip_grad) - if global_step % grad_accum_steps == 0: + accum_counter += 1 + if accum_counter == grad_accum_steps: + if clip_grad: + if scaler: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), clip_grad) if scaler: scaler.step(optimizer) scaler.update() @@ -417,6 +415,8 @@ def guidance_value(key: str, default: float) -> float: if scheduler is not None: scheduler.step() update_ema(ema_model, ddp_model.module, decay=ema_decay) + optimizer.zero_grad(set_to_none=True) + accum_counter = 0 running_loss += loss.item() epoch_metrics['loss'] += loss.detach()