From de8ae9652df8bf46a9cf27f5e85e8780279e35a5 Mon Sep 17 00:00:00 2001
From: Mark Obozov
Date: Fri, 11 Jul 2025 20:41:49 +0300
Subject: [PATCH] reimplement batched_rewards

---
 torchtune/dev/rl/rewards.py | 47 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 8d1ec1e79f..d4c6abe91d 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -7,10 +7,15 @@
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
+from torchtune.modules.transforms.tokenizers import (
+    HuggingFaceModelTokenizer,
+    ModelTokenizer,
+)
+
 
 @dataclass
 class RewardOutput:
@@ -216,3 +221,43 @@ def __call__(
             },
             successes=successes,
         )
+
+
+def batched_rewards(
+    tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
+    completions: torch.Tensor,
+    answers: list[list[str]],
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    reward_funcs = [
+        ThinkingAnswerFormattingReward("think", "answer", 1.0),
+        FormattedMathCorrectnessReward("think", 50.0),
+    ]
+
+    num_reward_funcs = len(reward_funcs)
+    batch_size, group_size, _ = completions.shape
+
+    rewards_tensor = torch.zeros(
+        batch_size, group_size, num_reward_funcs, dtype=torch.bfloat16, device=device
+    )
+    successes_tensor = torch.zeros_like(rewards_tensor)
+
+    completions_list = []
+    answers_list = []
+    for b in range(batch_size):
+        for g in range(group_size):
+            completion_text = tokenizer.decode(completions[b, g].tolist())
+            completions_list.append(completion_text)
+            answers_list.append(answers[b][g])
+
+    completions_ids = torch.zeros(len(completions_list))  # dummy
+
+    for rw_idx, reward_func in enumerate(reward_funcs):
+        reward_obj = reward_func(completions_ids, completions_list, answers_list)
+        batch_reward = reward_obj.total_reward.view(batch_size, group_size)
+        batch_success = reward_obj.successes.view(batch_size, group_size)
+
+        rewards_tensor[:, :, rw_idx] = batch_reward
+        successes_tensor[:, :, rw_idx] = batch_success
+
+    return rewards_tensor, successes_tensor
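
Testing note: a quick way to smoke-test the new batched_rewards end to end. The harness below is hypothetical and not part of this patch: _StubTokenizer fakes decode() (the only tokenizer method batched_rewards calls) so the example is self-contained, and it assumes the reward classes accept a flat list of decoded completions and return per-completion reward/success tensors, as the .view(batch_size, group_size) calls in the patch imply.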
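
    import torch

    from torchtune.dev.rl.rewards import batched_rewards


    class _StubTokenizer:
        # Hypothetical stand-in for ModelTokenizer / HuggingFaceModelTokenizer:
        # batched_rewards only calls .decode(), so that is all we fake here.
        def decode(self, token_ids: list[int]) -> str:
            return "<think>2 + 2 = 4</think><answer>4</answer>"


    batch_size, group_size, seq_len = 2, 4, 8
    # batched_rewards only reads .shape and decodes each [seq_len] row of ids,
    # so zero-filled dummy ids are enough for a smoke test.
    completions = torch.zeros(batch_size, group_size, seq_len, dtype=torch.long)
    answers = [["4"] * group_size for _ in range(batch_size)]

    rewards, successes = batched_rewards(
        _StubTokenizer(), completions, answers, device=torch.device("cpu")
    )

    # One slot per (sample, group member, reward function).
    print(rewards.shape)    # torch.Size([2, 4, 2])
    print(successes.shape)  # torch.Size([2, 4, 2])
    # Collapse the per-function axis for a single scalar reward per completion.
    total = rewards.sum(dim=-1)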
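
On the completions_ids placeholder: the "# dummy" tensor is passed only to satisfy the reward-function call signature. Both ThinkingAnswerFormattingReward and FormattedMathCorrectnessReward appear to score the decoded text, so the token ids are never read. If a future reward function does consume ids, this placeholder would need to carry the real completions tensor instead, flattened to [batch_size * group_size, seq_len] to line up with completions_list.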