Question on how DMD loss is calculated in code #68

@ruben-nithyaganesh

The code contains the following line:
loss = 0.5 * F.mse_loss(original_latents.float(), (original_latents-grad).detach().float(), reduction="mean")
(Line 241 in sd_guidance.py, and line 111 in edu_guidance.py).

I think I have a fairly good understanding of how original_latents and grad are computed and what they mean, but I am a little confused by the line quoted above. Does it correspond directly to the KL divergence gradient update (Equation 7 in the original DMD paper)? If so, are there any resources or explanations of how this line computes that?
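To make the question concrete, here is a minimal sketch of what the quoted line appears to do. It is a framework-free stand-in rather than the repository's code: the target `(original_latents - grad)` is held constant to mimic `.detach()`, the gradient is estimated by finite differences instead of autograd, and the helper names and the numeric values are hypothetical. Under those assumptions, the derivative of the loss with respect to `original_latents` comes out as `grad / N`, where `N` is the number of elements (the `1/N` comes from `reduction="mean"`).

```python
# Framework-free stand-in for:
#   loss = 0.5 * F.mse_loss(x, (x - grad).detach(), reduction="mean")
# The target is built once and then held constant, mimicking .detach().

def surrogate_loss(x, target):
    """0.5 * mean squared error between x and a fixed target."""
    n = len(x)
    return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target)) / n


def numerical_gradient(x, target, eps=1e-6):
    """Central finite-difference gradient of surrogate_loss w.r.t. x (target fixed)."""
    grads = []
    for i in range(len(x)):
        up = list(x)
        up[i] += eps
        down = list(x)
        down[i] -= eps
        diff = surrogate_loss(up, target) - surrogate_loss(down, target)
        grads.append(diff / (2 * eps))
    return grads


# Hypothetical values standing in for original_latents and grad.
original_latents = [0.3, -1.2, 0.8, 2.0]
grad = [0.5, -0.25, 0.1, -0.4]

# Constant target: the current latents minus the externally computed grad.
target = [x - g for x, g in zip(original_latents, grad)]

loss_gradient = numerical_gradient(original_latents, target)

# grad scaled by 1/N, which is what reduction="mean" contributes.
expected = [g / len(grad) for g in grad]
```

If this reading is right, the loss value itself is not meaningful; the line is a way to hand a precomputed `grad` to backpropagation so that it flows from `original_latents` into the generator's parameters. Whether `grad` matches Equation 7 then depends on how `grad` is computed upstream, which this sketch does not cover.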

Thanks for any help!
