Skip to content

Commit 015215d

Browse files
committed
Disable Diffusion Monitor b/c it was causing dimensionality issues
1 parent cd39b54 commit 015215d

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

iddpm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
GaussianDiffusion,
1515
LearnedVarianceGaussianDiffusion,
1616
cosine_betas,
17+
get_eps_and_var,
1718
)
1819

1920

@@ -126,15 +127,18 @@ def forward(self, batch):
126127
noise = th.randn_like(x_0)
127128
x_t = self.diffusion.q_sample(x_0, t, noise)
128129
model_out = self.model(x_t, t)
129-
x_0_pred = self.diffusion.predict_x0_from_eps(x_t=x_t, t=t, eps=model_out)
130+
# eps = model_out
131+
# if isinstance(self.diffusion, LearnedVarianceGaussianDiffusion):
132+
# eps, _ = get_eps_and_var(model_out, model_out.shape[1] // 2)
133+
# x_0_pred = self.diffusion.predict_x0_from_eps(x_t=x_t, t=t, eps=model_out)
130134
d = dict(
131135
noise=noise,
132136
model_out=model_out,
133137
# for logging & LearnedVarianceGaussianDiffusion
134138
x_t=x_t,
135139
t=t,
136140
# for logging
137-
x_0_pred=x_0_pred,
141+
# x_0_pred=x_0_pred,
138142
# for LearnedVarianceGaussianDiffusion
139143
x_0=x_0,
140144
)

trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def get_interval(total, times):
6969
LRMonitor(),
7070
SpeedMonitor(window_size=10),
7171
CheckpointSaver(save_interval=checkpoint_interval),
72-
DiffusionMonitor(interval=diffusion_log_interval),
72+
# DiffusionMonitor(interval=diffusion_log_interval),
7373
],
7474
)
7575
return trainer

0 commit comments

Comments
 (0)