File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change 1414 GaussianDiffusion ,
1515 LearnedVarianceGaussianDiffusion ,
1616 cosine_betas ,
17+ get_eps_and_var ,
1718)
1819
1920
@@ -126,15 +127,18 @@ def forward(self, batch):
126127 noise = th .randn_like (x_0 )
127128 x_t = self .diffusion .q_sample (x_0 , t , noise )
128129 model_out = self .model (x_t , t )
129- x_0_pred = self .diffusion .predict_x0_from_eps (x_t = x_t , t = t , eps = model_out )
130+ # eps = model_out
131+ # if isinstance(self.diffusion, LearnedVarianceGaussianDiffusion):
132+ # eps, _ = get_eps_and_var(model_out, model_out.shape[1] // 2)
133+ # x_0_pred = self.diffusion.predict_x0_from_eps(x_t=x_t, t=t, eps=model_out)
130134 d = dict (
131135 noise = noise ,
132136 model_out = model_out ,
133137 # for logging & LearnedVarianceGaussianDiffusion
134138 x_t = x_t ,
135139 t = t ,
136140 # for logging
137- x_0_pred = x_0_pred ,
141+ # x_0_pred=x_0_pred,
138142 # for LearnedVarianceGaussianDiffusion
139143 x_0 = x_0 ,
140144 )
Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ def get_interval(total, times):
6969 LRMonitor (),
7070 SpeedMonitor (window_size = 10 ),
7171 CheckpointSaver (save_interval = checkpoint_interval ),
72- DiffusionMonitor (interval = diffusion_log_interval ),
72+ # DiffusionMonitor(interval=diffusion_log_interval),
7373 ],
7474 )
7575 return trainer
You can’t perform that action at this time.
0 commit comments