diff --git a/grpo_qwen2vl.py b/grpo_qwen2vl.py
index fd53ea6..c054e8e 100644
--- a/grpo_qwen2vl.py
+++ b/grpo_qwen2vl.py
@@ -938,9 +938,20 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
 
         # Compute the loss
         advantages = inputs["advantages"]
+
+        # CHANGED: Explicitly calculate batch size from advantages tensor
+        # Ensure all tensors have matching batch size by explicitly slicing them
+        # This prevents potential shape mismatches during loss computation
+        batch_size = advantages.size(0)
+        completion_mask = completion_mask[:batch_size]
+        per_token_logps = per_token_logps[:batch_size]
+        ref_per_token_logps = ref_per_token_logps[:batch_size]
+
+        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
+        old_per_token_logps = old_per_token_logps[:batch_size]
+
         # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's computation (see
         # _generate_and_score_completions) and use per_token_logps.detach() instead.
-        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
         coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
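
For context, the following is a minimal, self-contained sketch of the PPO-style clipped per-token objective that the sliced tensors above feed into. It continues past the end of the hunk with the usual second coefficient and masked mean; the shapes, the epsilon value, and the random dummy tensors are illustrative assumptions, not the trainer's actual inputs.

# Sketch only: standalone illustration of the clipped per-token loss,
# not the trainer code. All tensors below are made-up stand-ins.
import torch

batch_size, seq_len, epsilon = 2, 4, 0.2

advantages = torch.randn(batch_size)                # one advantage per completion
per_token_logps = torch.randn(batch_size, seq_len)  # current policy log-probs
old_per_token_logps = per_token_logps.detach()      # num_iterations == 1 case
completion_mask = torch.ones(batch_size, seq_len)   # 1 where completion tokens exist

coef_1 = torch.exp(per_token_logps - old_per_token_logps)  # importance ratio
coef_2 = torch.clamp(coef_1, 1 - epsilon, 1 + epsilon)     # clipped ratio
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)  # PPO-style clipping
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
print(loss.item())

Because the batch-size slicing in the hunk forces advantages, per_token_logps, old_per_token_logps, and completion_mask to share the same leading dimension, the broadcasts in the last few lines cannot raise a shape mismatch.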