Commit 696fad6

tpx818tpx and tpx authored

[grpo]Tool rl: add reward func for ToolRL (#4694)

* init rl func
* fix
* fix nan
* add comment
* delete print
* fix lint

Co-authored-by: tpx <tpx@DT>

1 parent 8077330 commit 696fad6

File tree

2 files changed: +249 −1 lines changed

examples/train/grpo/plugin/plugin.py

Lines changed: 248 additions & 1 deletion
@@ -1,6 +1,8 @@
 import asyncio
+import os
 import re
 import textwrap
+from collections import Counter
 from copy import deepcopy
 from typing import Dict, List, Optional

@@ -363,7 +365,6 @@ class CodeRewardByJudge0(ORM):
     PYTHON_ID = 71

     def __init__(self):
-        import os
         self.endpoint = os.getenv('JUDGE0_ENDPOINT')
         assert self.endpoint is not None, (
             'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
@@ -449,13 +450,259 @@ def __call__(self, completions, **kwargs) -> List[float]:
         return rewards


+# ref implementation: https://github.com/qiancheng0/ToolRL/blob/main/verl/utils/reward_score/rlla.py
+# arxiv paper: https://arxiv.org/abs/2504.13958
+# MAX1STEP30MAX3: enable the two-stage reward setting for the Format and Correctness rewards
+# SCHEDULEREWARD: enable the dynamic (fine-grained) reward schedule for the Format and Correctness rewards
+# Correctness reward granularity:
+# COARSEREWARD -> coarse, INTERMEDIATEREWARD -> intermediate, REFINEDREWARD -> fine-grained
+class ToolUseFormatReward(ORM):
+
+    def __init__(self):
+        self.format_max_possible = 1.0
+        self.format_min_possible = 0.0
+
+    def __call__(self, completions, solution, global_step, **kwargs) -> List[float]:
+        max_possible_reward = self.format_max_possible
+        min_possible_reward = self.format_min_possible
+        # Two-stage (coarse) setting: divide training into two phases.
+        # The format reward range is [0, 1] while step < 30 and [0, 0.5] afterwards.
+        if str(os.getenv('MAX1STEP30MAX3', 0)) == '1':
+            if global_step >= 30:
+                max_possible_reward = self.format_max_possible / 2
+                min_possible_reward = self.format_min_possible / 2
+            else:
+                max_possible_reward = self.format_max_possible
+                min_possible_reward = self.format_min_possible
+
+        # Apply continuous interpolation between the two reward scales throughout training.
+        if str(os.getenv('SCHEDULEREWARD', 0)) == '1':
+            max_possible_reward = 2 - (2 - max_possible_reward) * global_step / 150
+            min_possible_reward = -2 + (2 + min_possible_reward) * global_step / 150
+            if max_possible_reward < 1.0:
+                max_possible_reward = 1.0
+            if min_possible_reward > -1.0:
+                min_possible_reward = -1.0
+
+        rewards = []
+        responses = completions
+
+        for response, ans in zip(responses, solution):
+            reward = min_possible_reward
+            if '<response>' in ans and '<tool_call>' not in ans:
+                pattern = r'^<think>.*?</think>\n<response>.*?</response>$'
+                if re.search(pattern, response,
+                             re.DOTALL) and response.count('<response>') == 1 and response.count('</response>') == 1:
+                    reward = max_possible_reward
+            elif '<response>' not in ans and '<tool_call>' in ans:
+                pattern = r'^<think>.*?</think>\n<tool_call>\n.*?\n</tool_call>$'
+                if re.search(pattern, response,
+                             re.DOTALL) and response.count('<tool_call>') == 1 and response.count('</tool_call>') == 1:
+                    reward = max_possible_reward
+            elif '<response>' in ans and '<tool_call>' in ans:
+                pattern = r'^<think>.*?</think>\n<tool_call>\n.*?\n</tool_call>\n<response>.*?</response>$'
+                if (re.search(pattern, response, re.DOTALL) and response.count('<tool_call>') == 1
+                        and response.count('</tool_call>') == 1 and response.count('<response>') == 1
+                        and response.count('</response>') == 1):
+                    reward = max_possible_reward
+            else:
+                pattern = r'^<think>.*?</think>$'
+                if re.search(pattern, response, re.DOTALL):
+                    reward = max_possible_reward
+
+            rewards.append(reward)
+
+        return rewards
+
+
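For reference (not part of the commit), a minimal sketch of how ToolUseFormatReward scores a single rollout, assuming this plugin.py is importable as `plugin` and ms-swift is installed so the ORM base class resolves; the env-var toggles (MAX1STEP30MAX3, SCHEDULEREWARD) are left unset, so the reward range is the default [0, 1]:

# Editor's sketch: exercising ToolUseFormatReward directly.
from plugin import ToolUseFormatReward  # hypothetical import path; adjust to where plugin.py lives

reward_fn = ToolUseFormatReward()
solution = ['<think>plan</think>\n<tool_call>\n{"name": "get_weather", "parameters": {"city": "Paris"}}\n</tool_call>']
good = ['<think>I should call the weather tool.</think>\n'
        '<tool_call>\n{"name": "get_weather", "parameters": {"city": "Paris"}}\n</tool_call>']
bad = ['<tool_call>{"name": "get_weather"}</tool_call>']  # no <think> block, wrong layout

print(reward_fn(good, solution, global_step=0))  # [1.0]
print(reward_fn(bad, solution, global_step=0))   # [0.0]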
+class ToolUseLengthReward(ORM):
+
+    def __init__(self):
+        self.length_max_possible = 1.0
+        self.length_min_possible = 0.0
+
+    # customized reward function: length
+    def __call__(self, completions, solution, global_step, **kwargs):
+        max_possible_reward = self.length_max_possible
+        min_possible_reward = self.length_min_possible
+        # SCHEDULELENGTH: enable the dynamic length reward (target length grows with the training step)
+        if os.getenv('SCHEDULELENGTH', 0) == '1':
+            max_reward_len = (640 - 384) * global_step / 105 + 384
+        else:
+            max_reward_len = 512
+        """Reward function that gives higher scores to longer completions."""
+        responses = completions
+        rewards = []
+
+        for response, ans in zip(responses, solution):
+            if '<think>' not in response or '</think>' not in response:
+                rewards.append(min_possible_reward)
+                continue
+            think_responses = response.split('<think>')[-1].split('</think>')[0].strip()
+            reward = round(len(think_responses.split()) / max_reward_len, 2)
+            if reward > 1.0:
+                reward = 1.0
+
+            final_reward = reward * (max_possible_reward - min_possible_reward) + min_possible_reward
+            rewards.append(final_reward)
+
+        return rewards
+
+
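The length reward scales the word count of the <think> block into [length_min_possible, length_max_possible], capping at the target length. A worked example of the arithmetic (not part of the commit), assuming the default max_reward_len of 512 with SCHEDULELENGTH unset:

# Editor's sketch: length-reward arithmetic with the default target of 512 words.
think_text = ' '.join(['word'] * 128)                        # a 128-word <think> block
reward = min(round(len(think_text.split()) / 512, 2), 1.0)   # 0.25
final_reward = reward * (1.0 - 0.0) + 0.0                    # rescaled into [0.0, 1.0] -> 0.25
print(final_reward)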
+class ToolUseCorrectnessReward(ORM):
+
+    def __init__(self):
+        if str(os.getenv('CORRECTMAX1', 0)) == '1':
+            self.tool_max_possible = 1.0
+            self.tool_min_possible = -1.0
+        else:
+            self.tool_max_possible = 3.0
+            self.tool_min_possible = -3.0
+
+    def match_score(self, list1, list2):
+        if list1 == list2:
+            return 1.0
+
+        if os.getenv('REFINEDREWARD', 0) == '1':
+            if list1 != list2:
+                return 0.0
+
+        if not list1 or not list2:
+            return 0.0
+
+        count1 = Counter(list1)  # Frequency count for list1
+        count2 = Counter(list2)  # Frequency count for list2
+
+        intersection = sum(min(count1[k], count2[k]) for k in count1.keys() & count2.keys())
+        max_possible = len(list1) + len(list2) - intersection
+
+        return intersection / max_possible if max_possible > 0 else 0.0
+
+    def compute_tool_call_reward(self, gt_tools, pd_tools, max_possible_reward, min_possible_reward):
+        if gt_tools == pd_tools:
+            return max_possible_reward
+
+        if os.getenv('COARSEREWARD', 0) == '1':
+            if gt_tools != pd_tools:
+                return min_possible_reward
+
+        gt_names = [tool['name'] for tool in gt_tools]
+        pd_names = [tool['name'] for tool in pd_tools]
+        score = self.match_score(list(gt_names), list(pd_names))
+
+        local_max_possible = 1.0
+        used_pd_indices = set()  # Keep track of matched pd_tools
+
+        for gt_tool in gt_tools:
+            gt_name = gt_tool['name']
+            gt_params = gt_tool['parameters']
+
+            if str(os.getenv('INTERMEDIATEREWARD', 0)) == '1':
+                local_max_possible += 1.0
+            else:
+                local_max_possible += 1.0 + len(gt_params)
+
+            best_match = None
+            best_match_score = 0.0
+            best_match_index = -1
+
+            # Find the best matching unused pd_tool
+            for i, pd_tool in enumerate(pd_tools):
+                if i in used_pd_indices or pd_tool['name'] != gt_name:
+                    continue
+
+                if str(os.getenv('INTERMEDIATEREWARD', 0)) == '1':
+                    if gt_tool == pd_tool:
+                        best_match = pd_tool
+                        best_match_index = i
+                        best_match_score = 1.0
+                        break
+                    else:
+                        continue
+
+                pd_params = pd_tool['parameters']
+                param_score = self.match_score(list(gt_params.keys()), list(pd_params.keys()))
+
+                # Calculate correctness score for parameter values
+                correctness_score = sum(1.0 for k, v in gt_params.items() if k in pd_params and pd_params[k] == v)
+
+                total_score = param_score + correctness_score
+
+                if total_score > best_match_score:
+                    best_match_score = total_score
+                    best_match = pd_tool
+                    best_match_index = i
+
+            if best_match:
+                used_pd_indices.add(best_match_index)
+                score += best_match_score
+
+        return (max_possible_reward - min_possible_reward) * score / local_max_possible + min_possible_reward
+
+    # customized reward function: tool call correctness
+    def __call__(self, completions, solution, global_step, **kwargs):
+        max_possible_reward = self.tool_max_possible
+        min_possible_reward = self.tool_min_possible
+        # Two-stage (coarse) setting: divide training into two phases.
+        if str(os.getenv('MAX1STEP30MAX3', 0)) == '1':
+            if global_step < 30:
+                max_possible_reward = max_possible_reward / 3
+                min_possible_reward = min_possible_reward / 3
+            else:
+                max_possible_reward = max_possible_reward
+                min_possible_reward = min_possible_reward
+        # Apply continuous interpolation between the two reward scales throughout training.
+        if str(os.getenv('SCHEDULEREWARD', 0)) == '1':
+            max_possible_reward = (max_possible_reward - 2) * global_step / 150 + 2
+            min_possible_reward = (min_possible_reward + 2) * global_step / 150 - 2
+            if max_possible_reward > 3.0:
+                max_possible_reward = 3.0
+            if min_possible_reward < -3.0:
+                min_possible_reward = -3.0
+
+        responses = completions
+        rewards = []
+
+        for response, ans in zip(responses, solution):
+            reward = 0.0
+
+            if '<tool_call>' not in ans:
+                # if "<tool_call>" not in response and "</tool_call>" not in response:
+                #     reward = max_possible_reward
+                # else:
+                #     reward = min_possible_reward
+                rewards.append(reward)
+                continue
+
+            gt_tool_call = ans.split('<tool_call>')[1].split('</tool_call>')[0].strip()
+            gt_tools = gt_tool_call.split('\n')
+            gt_tools = [json.loads(tool) for tool in gt_tools]  # each dict contains "name" and "parameters"
+
+            try:
+                # if the format is not correct, directly give the lowest possible score
+                assert '<tool_call>' in response
+                assert '</tool_call>' in response
+                pd_tools = response.split('<tool_call>')[1].split('</tool_call>')[0].strip().split('\n')
+                pd_tools = [json.loads(tool) for tool in pd_tools]
+                reward = self.compute_tool_call_reward(gt_tools, pd_tools, max_possible_reward,
+                                                       min_possible_reward)
+            except (ValueError, IndexError, AssertionError):
+                reward = min_possible_reward
+
+            rewards.append(reward)
+
+        return rewards
+
+
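match_score above is the building block of compute_tool_call_reward: it compares tool-name lists and parameter-key lists as multisets, scoring intersection size over union size. A standalone sketch of that formula (not part of the commit; the REFINEDREWARD and exact-match branches are omitted, and the helper name is the editor's):

# Editor's sketch of the multiset-overlap formula used by match_score.
from collections import Counter

def overlap(list1, list2):
    if list1 == list2:
        return 1.0
    if not list1 or not list2:
        return 0.0
    c1, c2 = Counter(list1), Counter(list2)
    inter = sum(min(c1[k], c2[k]) for k in c1.keys() & c2.keys())
    return inter / (len(list1) + len(list2) - inter)

# Tool-name level: one of two predicted names is right -> 1 / (2 + 2 - 1) = 1/3
print(overlap(['get_weather', 'get_time'], ['get_weather', 'get_news']))
# Parameter-key level: only one of two expected keys supplied -> 1 / (2 + 1 - 1) = 0.5
print(overlap(['city', 'unit'], ['city']))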
 orms['external_math_acc'] = MathAccuracy
 orms['external_math_format'] = MathFormat
 orms['external_countdown'] = CountdownORM
 orms['external_r1v_acc'] = MultiModalAccuracyORM
 orms['external_code_reward'] = CodeReward
 orms['external_code_format'] = CodeFormat
 orms['external_code_reward_by_judge0'] = CodeRewardByJudge0
+orms['external_tooluse_format_reward'] = ToolUseFormatReward
+orms['external_tooluse_length_reward'] = ToolUseLengthReward
+orms['external_tooluse_correct_reward'] = ToolUseCorrectnessReward
 """
 TO CUSTOMIZE REWARD MODEL:
 Step 1: Define a Reward Class
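The three new entries in the orms registry are the names by which these rewards are selected; in ms-swift's GRPO examples such keys are typically passed to --reward_funcs together with --external_plugins pointing at this file (flag usage follows the existing examples and is not restated in this commit). A minimal sketch (not part of the commit) of scoring one rollout with all three rewards, assuming plugin.py is importable and the env-var toggles are unset:

# Editor's sketch: combining the three ToolRL rewards on a single completion.
from plugin import ToolUseCorrectnessReward, ToolUseFormatReward, ToolUseLengthReward  # hypothetical import path

solution = ['<think>plan</think>\n<tool_call>\n{"name": "search", "parameters": {"query": "ToolRL"}}\n</tool_call>']
completion = ['<think>I need the search tool.</think>\n'
              '<tool_call>\n{"name": "search", "parameters": {"query": "ToolRL"}}\n</tool_call>']

total = 0.0
for reward_cls in (ToolUseFormatReward, ToolUseLengthReward, ToolUseCorrectnessReward):
    total += reward_cls()(completion, solution, global_step=0)[0]
print(total)  # format 1.0 + length 0.01 + correctness 3.0 (exact tool-call match) = 4.01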

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 0 deletions
@@ -898,6 +898,7 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te
         else:
             # Repeat all input columns (but "messages" and "completion") to match the number of generations
             reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
+            reward_kwargs['global_step'] = self.state.global_step
             output_reward_func = reward_func(completions, **reward_kwargs)
         output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
         rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
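The single added line threads the trainer's global step into the kwargs of every custom reward function, which is what lets the ToolRL rewards above schedule their reward scales. A minimal sketch (not part of the commit) of a step-aware reward callable that consumes it; the class name and annealing rule are illustrative only:

# Editor's sketch: a custom reward that reads `global_step` from the forwarded reward_kwargs.
from typing import List

class StepAwareReward:

    def __call__(self, completions, global_step, **kwargs) -> List[float]:
        # Anneal the reward ceiling from 1.0 to 2.0 over the first 150 steps.
        ceiling = 1.0 + min(global_step, 150) / 150
        return [min(len(c.split()) / 100, ceiling) for c in completions]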
