Skip to content

Commit 3974b96

Browse files
authoredFeb 14, 2025
Merge pull request #1 from silky1708/main
Fixes in ddpm.py and main.py CompVis#851
2 parents 21f890f + 3cb1065 commit 3974b96

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed
 

‎ldm/models/diffusion/ddpm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ def p_losses(self, x_start, cond, t, noise=None):
10271027
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
10281028
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
10291029

1030-
logvar_t = self.logvar[t].to(self.device)
1030+
logvar_t = self.logvar.to(self.device)[t]
10311031
loss = loss_simple / torch.exp(logvar_t) + logvar_t
10321032
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
10331033
if self.learn_logvar:

‎main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def on_train_epoch_start(self, trainer, pl_module):
400400
torch.cuda.synchronize(trainer.root_gpu)
401401
self.start_time = time.time()
402402

403-
def on_train_epoch_end(self, trainer, pl_module, outputs):
403+
def on_train_epoch_end(self, trainer, pl_module):
404404
torch.cuda.synchronize(trainer.root_gpu)
405405
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
406406
epoch_time = time.time() - self.start_time

0 commit comments

Comments
 (0)
Please sign in to comment.