
Conversation

hhaAndroid
Collaborator

The loss calculation will convert logits to float32, so for alignment we also need to convert them to float32 here; otherwise the ratio will not equal 1 during on-policy RL training.

logits = F.linear(hidden_states, w, b)
return None, logits
# Note: the loss calculation will convert logits to float32, so for alignment,
# we also need to convert it to float32 here to prevent the ratio from being 1 during rl training
Contributor

Suggested change
- # we also need to convert it to float32 here to prevent the ratio from being 1 during rl training
+ # we also need to convert it to float32 to prevent the ratio from not being equal to 1 during on-policy rl training
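
For illustration, a minimal sketch (not the PR's actual code) of the cast under discussion; hidden_states, w, and b follow the quoted snippet, while head_forward and the .float() call are assumptions:

import torch
import torch.nn.functional as F

def head_forward(hidden_states, w, b=None):
    # Project hidden states with the (bf16) lm-head weights, as in the quoted snippet.
    logits = F.linear(hidden_states, w, b)
    # Assumed alignment cast: upcast so the logprobs behind the importance ratio
    # match the float32 logprobs used in the loss; without it the on-policy ratio
    # exp(new_logprob - old_logprob) can drift away from 1.
    return None, logits.float()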

@pppppM
Collaborator

pppppM commented Sep 18, 2025

Because chunked forward is not used when computing old_logprobs, this may cause OOM (out of memory).
From the inference perspective, the returned logits need to stay in bf16.
It is best to convert to float32 near where the logprobs are computed.
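
A hedged sketch of this proposal, using hypothetical names (chunked_logprobs, lm_head_weight, chunk_size): the forward keeps logits in bf16, and the float32 upcast happens per chunk right next to the logprob computation, so full-vocabulary float32 logits are never materialized at once:

import torch
import torch.nn.functional as F

def chunked_logprobs(hidden_states, lm_head_weight, labels, chunk_size=1024):
    # hidden_states: (B, T, H) bf16, lm_head_weight: (V, H) bf16, labels: (B, T) int64.
    flat_h = hidden_states.flatten(0, 1)   # (B*T, H)
    flat_y = labels.flatten()              # (B*T,)
    out = []
    for start in range(0, flat_h.size(0), chunk_size):
        h = flat_h[start:start + chunk_size]
        y = flat_y[start:start + chunk_size]
        logits = F.linear(h, lm_head_weight)                   # stays bf16, (chunk, V)
        logprobs = torch.log_softmax(logits.float(), dim=-1)   # upcast only here
        out.append(logprobs.gather(-1, y.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out).view_as(labels)

This way inference-side callers still see bf16 logits from the model forward, while old_logprobs and new logprobs are both computed in float32, keeping the on-policy ratio at 1.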
