In the code there is the following line:
loss = 0.5 * F.mse_loss(original_latents.float(), (original_latents-grad).detach().float(), reduction="mean")
(line 241 in sd_guidance.py and line 111 in edu_guidance.py).
I think I have a fairly good understanding of how original_latents and grad are computed and what they mean, but I am a little confused by this particular line. Does it correspond directly to the KL divergence gradient update (Equation 7 in the original DMD paper)? If so, are there any resources or explanations of how this line ends up computing that update?
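As a sanity check of what I think the line does mechanically: because the target `(original_latents - grad)` is detached, it acts as a constant, so with `reduction="mean"` over N elements the loss is `0.5 * mean(grad**2)` and its gradient with respect to `original_latents` is `(original_latents - target) / N = grad / N`. In other words, backprop seems to push exactly `grad` (scaled by 1/N) into the generator through the chain rule. The sketch below checks this with random tensors standing in for `original_latents` and `grad`; the shape `(2, 4, 8, 8)` and the names `x` and `grad` are arbitrary assumptions, not taken from the repo.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: x plays the role of original_latents (requires grad),
# grad plays the role of the precomputed gradient tensor from the repo.
x = torch.randn(2, 4, 8, 8, requires_grad=True)
grad = torch.randn(2, 4, 8, 8)

# Same form as the line in sd_guidance.py / edu_guidance.py.
# The target is detached, so autograd treats it as a constant.
loss = 0.5 * F.mse_loss(x.float(), (x - grad).detach().float(), reduction="mean")
loss.backward()

N = x.numel()
# d(loss)/dx = (x - (x - grad)) / N = grad / N
print(torch.allclose(x.grad, grad / N))  # expected: True
```

If that reading is right, whether the line matches Equation 7 would come down to whether `grad` itself is the (weighted) score difference the paper describes; the MSE form would just be a way to inject that precomputed gradient into `loss.backward()`.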
Thanks for any help!