On my first attempt, I wrote my own version of the DPO loss. However, my loss is 0.5147, which is not close to the expected value of 0.5785 in `test_dpo.py`.
To debug my code, I refactored the DPO loss from the `trl` package, but the result is still 0.5147.
Can the community help me check whether the test case is correct? The following is my code, refactored from `trl`:
import torch
from transformers import PreTrainedTokenizerBase, PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from torch import nn
import torch
def compute_forward(
    model: nn.Module, batch: dict[str, torch.LongTensor]
) -> torch.Tensor:
    """
    Run `model` on prompt+completion sequences and return the summed
    log-probability of the completion tokens.

    The prompt and completion tensors are concatenated along the sequence
    dimension (prompt first). Only positions where
    `completion_attention_mask` is non-zero contribute to the sum; every
    prompt position is masked out.

    batch:
        prompt_input_ids: torch.LongTensor shape [bsize, prompt_len]
        prompt_attention_mask: torch.LongTensor shape [bsize, prompt_len]
        completion_input_ids: torch.LongTensor shape [bsize, completion_len]
        completion_attention_mask: torch.LongTensor shape [bsize, completion_len]

    Returns:
        torch.Tensor of shape [bsize]: for each row, the sum over completion
        tokens of log p(token_t | tokens_<t) under `model`.
    """
    # Tokenizers return CPU tensors; move them next to the model so a
    # GPU-resident model does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else batch["prompt_input_ids"].device

    prompt_input_ids = batch["prompt_input_ids"].to(device)
    prompt_attention_mask = batch["prompt_attention_mask"].to(device)
    completion_input_ids = batch["completion_input_ids"].to(device)
    completion_attention_mask = batch["completion_attention_mask"].to(device)

    input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
    attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
    # True on (non-padded) completion tokens, False on every prompt token.
    loss_mask = torch.cat(
        (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
        dim=1,
    ).bool()

    logits = model(input_ids, attention_mask=attention_mask).logits

    # The logit at position t scores token t+1: drop the last logit and the
    # first token so predictions and targets line up.
    # (assumes causal-LM logits of shape [bsize, seq_len, vocab] — TODO confirm for non-HF models)
    shifted_logits = logits[:, :-1, :]
    labels = input_ids[:, 1:].clone()
    label_mask = loss_mask[:, 1:]
    # Masked targets get a dummy (valid) index so the gather is safe; their
    # log-probs are zeroed right after.
    labels[~label_mask] = 0
    per_token_logps = selective_log_softmax(shifted_logits, labels)
    per_token_logps = per_token_logps.masked_fill(~label_mask, 0.0)
    return per_token_logps.sum(-1)
def selective_log_softmax(logits, index) -> torch.Tensor:
    """
    Return log_softmax(logits)[..., index] without materialising the full
    log-softmax for the whole batch at once.

    Equivalent to the naive form:
    ```python
    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    ```
    Args:
        logits (`torch.Tensor`):
            Scores of shape `(..., num_classes)`.
        index (`torch.Tensor`):
            Class ids of shape `(...)` to pick from each distribution.
    Returns:
        `torch.Tensor`:
            Log-probabilities of the selected classes, shaped like `index`.
    """
    if logits.dtype not in (torch.float32, torch.float64):
        # Reduced-precision dtypes (e.g. bfloat16): the logsumexp shortcut is
        # unstable, so take a real log_softmax one leading-dim row at a time.
        picked_rows = [
            F.log_softmax(row_scores, dim=-1)
            .gather(dim=-1, index=row_ids.unsqueeze(-1))
            .squeeze(-1)
            for row_scores, row_ids in zip(logits, index)
        ]
        return torch.stack(picked_rows)

    # log_softmax(x)_i = x_i - logsumexp(x); normalisers are computed per row
    # to keep peak memory low.
    picked = logits.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    normalisers = torch.stack([torch.logsumexp(row_scores, dim=-1) for row_scores in logits])
    return picked - normalisers
def compute_per_instance_dpo_loss(
    lm: PreTrainedModel,
    lm_ref: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    beta: float,
    prompt: str,
    response_chosen: str,
    response_rejected: str,
) -> torch.Tensor:
    """
    Given two language models (`lm`, and the "reference model" `lm_ref`),
    their tokenizer, the DPO beta hyperparameter, a prompt and a pair
    of responses to the prompt, computes the value of the DPO loss for this example:

        -log sigmoid(beta * ((log pi(y_w|x) - log pi(y_l|x))
                             - (log pi_ref(y_w|x) - log pi_ref(y_l|x))))

    where each log-probability is summed over the response tokens only.

    lm: torch.nn.Module
        Language model being trained.
    lm_ref: torch.nn.Module
        Reference language model. Treated as frozen: its forward passes run
        without autograd, so no gradient reaches it.
    tokenizer: PreTrainedTokenizerBase
        Tokenizer for both language models.
    beta: float
        DPO beta hyperparameter.
    prompt: str
        Prompt for this instance of preference pair.
    response_chosen: str
        Preferred response to the prompt.
    response_rejected: str
        Rejected response to the prompt.
    Returns:
        torch.Tensor with the DPO loss for this example.

    NOTE(review): the prompt and each response are tokenized separately and
    joined with no separator, no prompt template and no end-of-sequence
    token. For "... jumps over" + "the lazy dog." the model therefore scores
    "the" directly after "over" with no space. The loss arithmetic is
    correct for that input. A mismatch with a reference value (0.5147 here
    vs 0.5785 in test_dpo.py) most likely comes from input formatting, e.g.
    a prompt template and/or an appended EOS token expected by the test.
    TODO: confirm against the spec that produced the expected value.
    """
    prompt_tokens = tokenizer(prompt, return_tensors='pt')

    def _make_batch(response: str) -> dict[str, torch.Tensor]:
        # Pair the shared prompt encoding with one response encoding.
        response_tokens = tokenizer(response, return_tensors='pt')
        return {
            'prompt_input_ids': prompt_tokens['input_ids'],
            'prompt_attention_mask': prompt_tokens['attention_mask'],
            'completion_input_ids': response_tokens['input_ids'],
            'completion_attention_mask': response_tokens['attention_mask'],
        }

    batch_chosen = _make_batch(response_chosen)
    batch_rejected = _make_batch(response_rejected)

    # The reference policy is frozen in DPO. Running it under no_grad means
    # backward() neither writes lm_ref's .grad nor keeps its activations alive.
    logps_chosen = compute_forward(lm, batch_chosen)
    with torch.no_grad():
        ref_logps_chosen = compute_forward(lm_ref, batch_chosen)
    logps_rejected = compute_forward(lm, batch_rejected)
    with torch.no_grad():
        ref_logps_rejected = compute_forward(lm_ref, batch_rejected)

    logratios = logps_chosen - logps_rejected
    ref_logratios = ref_logps_chosen - ref_logps_rejected
    logits = logratios - ref_logratios
    losses = -F.logsigmoid(beta * logits)
    return losses.mean()
if __name__ == '__main__':
    # Manual check: tiny policy/reference checkpoints from the fixtures dir,
    # sharing the stock GPT-2 tokenizer.
    FIXTURES_PATH = './tests/fixtures'
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    m = AutoModelForCausalLM.from_pretrained(f"{FIXTURES_PATH}/tiny-gpt2")
    m_ref = AutoModelForCausalLM.from_pretrained(f"{FIXTURES_PATH}/tiny-gpt2-ref")

    prompt, good_response, bad_response = (
        "The quick brown fox jumps over",
        "the lazy dog.",
        "their crazy frog.",
    )
    loss = compute_per_instance_dpo_loss(
        m,
        m_ref,
        tokenizer,
        beta=0.5,
        prompt=prompt,
        response_chosen=good_response,
        response_rejected=bad_response,
    )
    print(loss)
On my first attempt, I wrote my own version of the DPO loss. However, my loss is
0.5147, which is not close to the expected value of 0.5785 in `test_dpo.py`. To debug my code, I refactored the DPO loss from the `trl` package, but the result is still
0.5147. Can the community help me check whether the test case is correct? The following is my code, refactored from `trl`: