Skip to content

Is the testcase of dpo loss correct? #4

@tfzzzh

Description

@tfzzzh

At first try, I write my own version of dpo loss. However my loss is 0.5147 which is not close to the answer 0.5785 in test_dpo.py.

To debug my code, I refactor dpo loss from the package trl. But the answer is still 0.5147 .

Can the community help me find if the test case is correct? The following is my refactored code from trl:

import torch
from transformers import PreTrainedTokenizerBase, PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from torch import nn
import torch


def compute_forward(
     model: nn.Module, batch: dict[str, torch.LongTensor]
) -> torch.Tensor:
    """
    Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    batch:
        prompt_input_ids: torch.LongTensor shape [bsize, nseqlen]
        prompt_attention_mask: torch.LongTensor shape [bsize, nseqlen]
        completion_input_ids: torch.LongTensor shape [bsize, nseqlen]
        completion_attention_mask: torch.LongTensor shape [bsize, nseqlen]
    """
    prompt_input_ids = batch["prompt_input_ids"]
    prompt_attention_mask = batch["prompt_attention_mask"]
    completion_input_ids = batch["completion_input_ids"]
    completion_attention_mask = batch["completion_attention_mask"]

    # Concatenate the prompt and completion inputs
    input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
    attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
    # Mask the prompt but not the completion for the loss
    loss_mask = torch.cat(
        (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
        dim=1,
    )

    # attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
    outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    # Offset the logits by one to align with the labels
    labels = torch.roll(input_ids, shifts=-1, dims=1)
    loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()

    labels[~loss_mask] = 0  # dummy token; we'll ignore the losses on these tokens later
    per_token_logps = selective_log_softmax(logits, labels)
    per_token_logps[~loss_mask] = 0
    per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
    logps = per_token_logps[:, 1:].sum(-1)

    return logps


def selective_log_softmax(logits, index) -> torch.Tensor:
    """
    A memory-efficient implementation of the common `log_softmax -> gather` operation.

    This function is equivalent to the following naive implementation:
    ```python
    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    ```

    Args:
        logits (`torch.Tensor`):
            Logits tensor of shape `(..., num_classes)`.
        index (`torch.Tensor`):
            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.

    Returns:
        `torch.Tensor`:
            Gathered log probabilities with the same shape as `index`.
    """
    if logits.dtype in [torch.float32, torch.float64]:
        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        # loop to reduce peak mem consumption
        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
        per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    else:
        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
        per_token_logps = []
        for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption
            row_logps = F.log_softmax(row_logits, dim=-1)
            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
            per_token_logps.append(row_per_token_logps)
        per_token_logps = torch.stack(per_token_logps)
    return per_token_logps


def compute_per_instance_dpo_loss(
    lm: PreTrainedModel,
    lm_ref: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    beta: float,
    prompt: str,
    response_chosen: str,
    response_rejected: str,
) -> torch.Tensor:
    """
    Given two language models (`lm`, and the "reference model" `lm_ref`),
    their tokenizer, the DPO beta hyperparameter, a prompt and a pair
    of responses to the prompt, computes the value of the DPO loss for this example.

    lm: torch.nn.Module
        Language model being trained.
    lm_ref: torch.nn.Module
        Reference language model.
    tokenizer: PreTrainedTokenizerBase
        Tokenizer for both language models.
    beta: float
        DPO beta hyperparameter.
    prompt: str
        Prompt for this instance of preference pair.
    response_chosen: str
        Preferred response to the prompt.
    response_rejected: str
        Rejected response to the prompt.

    Returns:
        torch.Tensor with the DPO loss for this example.
    """
    prompt_token = tokenizer(prompt, return_tensors='pt')
    good_response_token = tokenizer(response_chosen, return_tensors='pt')
    bad_response_token = tokenizer(response_rejected, return_tensors='pt')

    batch_good = {}
    batch_good['prompt_input_ids'] = prompt_token['input_ids']
    batch_good['prompt_attention_mask'] = prompt_token['attention_mask']
    batch_good['completion_input_ids'] = good_response_token['input_ids']
    batch_good['completion_attention_mask'] = good_response_token['attention_mask']

    batch_bad = {}
    batch_bad['prompt_input_ids'] = prompt_token['input_ids']
    batch_bad['prompt_attention_mask'] = prompt_token['attention_mask']
    batch_bad['completion_input_ids'] = bad_response_token['input_ids']
    batch_bad['completion_attention_mask'] = bad_response_token['attention_mask']

    model_logs_good = compute_forward(lm, batch_good)
    model_logs_good_ref = compute_forward(lm_ref, batch_good)
    model_logs_bad = compute_forward(lm, batch_bad)
    model_logs_bad_ref = compute_forward(lm_ref, batch_bad)
 

    logratios = model_logs_good - model_logs_bad
    ref_logratios = model_logs_good_ref - model_logs_bad_ref
    logits = logratios - ref_logratios
    losses =  -F.logsigmoid(beta * logits) 

    return losses.mean()


if __name__ == '__main__':
    FIXTURES_PATH = './tests/fixtures'
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    m = AutoModelForCausalLM.from_pretrained(FIXTURES_PATH + "/tiny-gpt2")
    m_ref = AutoModelForCausalLM.from_pretrained(FIXTURES_PATH + "/tiny-gpt2-ref")

    prompt = "The quick brown fox jumps over"
    good_response = "the lazy dog."
    bad_response = "their crazy frog."

    loss = compute_per_instance_dpo_loss(m, m_ref, tokenizer, 0.5, prompt, good_response, bad_response)
    print(loss)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions