
Question about discrete score understanding #16

@tulvgengenr


Hello! I am very interested in your paper "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution"; it is an excellent paper!

I would like to build on the theory of this paper in my own project, but I have a question I would like to ask you. I would be very grateful for your answer.

The work my project builds on uses discrete diffusion with the absorbing transition matrix $Q^{\text{absorb}}$ for its noise schedule. A token becomes "mask_id" with probability $1 - e^{-\sigma(t)}$ and stays the same with probability $e^{-\sigma(t)}$. However, the model outputs a probability distribution for each token in the sequence (just like an LLM), so the output tensor has shape (B, L, N). How can I convert the model output into scores, so that I can compute the loss directly with your code?
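
For concreteness, here is a minimal sketch of the forward noising as I understand it (the function name absorb_forward and its arguments are my own; sigma_t is the cumulative noise at time t and mask_id is the absorbing token):

    import torch

    def absorb_forward(x0, sigma_t, mask_id):
        # x0: [B, L] clean token ids; sigma_t: [B] cumulative noise at time t.
        # Each token independently becomes mask_id with probability
        # 1 - exp(-sigma_t) and stays unchanged with probability exp(-sigma_t).
        move_prob = 1.0 - torch.exp(-sigma_t)[:, None]             # [B, 1]
        move = torch.rand(x0.shape, device=x0.device) < move_prob  # [B, L]
        return torch.where(move, torch.full_like(x0, mask_id), x0)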

To my understanding, I need to compute, at each position, the ratio of the predicted probability of "mask_id" to the predicted probability of each token, as in the following code:

def p2score(pred):
    '''
    pred: [B, L, N], probability distribution per position;
    mask_id is the last vocabulary index (N - 1)
    '''
    # Probability of mask_id divided by the probability of each token
    pred_mask = pred[:, :, -1].unsqueeze(-1)  # [B, L, 1]
    score = pred_mask / pred                  # broadcasts to [B, L, N]
    return score
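
For context, this is how I would call it on the raw model output; the softmax over the last dimension and the names model and x_t are my assumptions:

    import torch.nn.functional as F

    logits = model(x_t)                # [B, L, N] raw model output (placeholder call)
    pred = F.softmax(logits, dim=-1)   # [B, L, N] per-position probability distribution
    score = p2score(pred)              # [B, L, N] candidate score tensor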

Then I could compute the loss directly with your code. Here is your score_entropy implementation:

    def score_entropy(self, score, sigma, x, x0):
        # rel_ind marks the absorbed positions, i.e. tokens equal to mask_id (self.dim - 1)
        rel_ind = x == self.dim - 1
        # numerically stable exp(sigma) - 1
        esigm1 = torch.where(
            sigma < 0.5,
            torch.expm1(sigma),
            torch.exp(sigma) - 1
        )
        ratio = 1 / esigm1.expand_as(x)[rel_ind]
        other_ind = x0[rel_ind]
        # negative term: gather the entry at the clean token x0 for each masked position
        neg_term = ratio * torch.gather(score[rel_ind], -1, other_ind[..., None]).squeeze(-1)
        # positive term: exponentiate and sum over all non-mask vocabulary entries
        pos_term = score[rel_ind][:, :-1].exp().sum(dim=-1)
        # constant term
        const = ratio * (ratio.log() - 1)
        entropy = torch.zeros(*x.shape, device=x.device)
        entropy[rel_ind] += pos_term - neg_term + const
        return entropy
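
And this is how I imagine wiring everything together; graph, model, x_t, x0, sigma, and dsigma are placeholder names on my side, and I am assuming self.dim equals the vocabulary size N with mask_id = N - 1:

    import torch.nn.functional as F

    # Placeholder wiring: the exact reduction is my assumption, not your code
    pred = F.softmax(model(x_t), dim=-1)                           # [B, L, N]
    score = p2score(pred)                                          # [B, L, N]
    entropy = graph.score_entropy(score, sigma[:, None], x_t, x0)  # [B, L]
    loss = (dsigma[:, None] * entropy).sum(dim=-1).mean()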

Is my understanding correct? I would be very grateful for your reply!
