[refactor] refactor rl data structure in dataflow #1110
Merged
Changes from all commits · 4 commits
@@ -42,9 +42,9 @@ def init_config(self):
             tensor_parallel_size=8,
         )
         from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
-        gsm8k_judger_config = GSM8KJudgerConfig()
+        gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
         self.judger_cfg = JudgerConfig(
-            reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+            reward_judger_configs=[gsm8k_judger_config]
         )

         self.eval_dataset_cfg = [
@@ -82,7 +82,7 @@ def tearDown(self):
     @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
     def test_lmdeploy_evaluator(self):
         def custom_compute_metric(samples):
-            return {"custom_accuracy": sum(s["reward"] > 0 for s in samples) / len(samples)}
+            return {"custom_accuracy": sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples) / len(samples)}

Review comment on the line above: now that a dict is returned, this part becomes much simpler.

         evaluator_cfg = EvaluatorConfig(
             dataset_cfg=self.eval_dataset_cfg,
@@ -93,7 +93,6 @@ def custom_compute_metric(samples):
         )
         evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
         correctness = ray.get(evaluator.run.remote(sample_params=self.sample_params))
-
         custom_evaluator_cfg = EvaluatorConfig(
             dataset_cfg=self.eval_dataset_cfg,
             tokenizer=self.tokenizer,
@@ -6,49 +6,63 @@
 import ray
 import unittest
 import numpy as np

+from uuid import uuid4
 from xtuner.v1.ray.environment import SingleTurnEnvironment
 from xtuner.v1.ray.config.worker import RolloutConfig
 from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig, AutoAcceleratorWorkers
 from xtuner.v1.ray.judger.controller import JudgerController, JudgerConfig
-from xtuner.v1.datasets.data_item import RLTextDataItem
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem, RLEnvDataItem, RLRolloutResponseItem, RLUIDItem


 MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
 DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
 VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"]

-FAKE_INPUT_DATA_ITEM = {
-    'messages': [{
-        'role': 'user', 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"'
-    }],
-    'num_tokens': 62,
-    'reward_model': {'ground_truth': '72', 'style': 'rule'},
-    'ability': 'math',
-    'data_source': {'openai/gsm8k': 1.0},
-    'extra_info': {'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72', 'index': 0, 'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'split': 'train', 'raw_prompt': '<|im_start|>user\nNatalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####".<|im_end|>\n<|im_start|>assistant\n'},
-    'env': 'test_env',
-    'group_id': 255971142656329732139546771377476227093,
-    'prompt_id': 22175756018538642401581407443664245296,
-    'retry_times': 0}

-FAKE_JUDGER_INPUT_ITEM = copy.deepcopy(FAKE_INPUT_DATA_ITEM)
-FAKE_JUDGER_INPUT_ITEM["response_str"] = "<think>\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n</think>\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. \n\n#### 72"
-FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM] * 2

Review comment on the new fixture construction below: does this support passing in one big dict that is converted to an RLDataFlowItem automatically, instead of having to construct every sub-object by hand? (A conversion sketch follows this hunk.)

+FAKE_JUDGER_INPUT_ITEM = RLDataFlowItem(
+    uid = RLUIDItem(action_id=uuid4().int,
+                    observation_id=uuid4().int),
+    data = RLDatasetItem(
+        messages=[{
+            'role': 'user', 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"'
+        }],
+        num_tokens=62,
+        reward_model={'ground_truth': '72', 'style': 'rule'},
+        ability='math',
+        data_source={'openai/gsm8k': 1.0}
+    ),
+    env = RLEnvDataItem(
+        rollout=RLRolloutResponseItem(
+            response="<think>\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n</think>\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. \n\n#### 72<|im_end|>",
+        )
+    )
+)
+FAKE_JUDGER_INPUT_ITEM_1 = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM)
+FAKE_JUDGER_INPUT_ITEM_1.uid.observation_id = uuid4().int
+FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM, FAKE_JUDGER_INPUT_ITEM_1]  # the action_id marks these as different input data
 FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM)
-FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE['data_source'] = {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5}
+FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE.data.data_source = {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5}

 def construct_judger_data(data_path):
     dataitem = []
     with open(data_path, 'r', encoding='utf-8') as f:
         for line_num, line in enumerate(f, 1):
             # strip trailing whitespace and parse the JSON line
             data = json.loads(line.strip())
-            data_item = RLTextDataItem(
-                messages=data['input'],
-                reward_model={"ground_truth": data["gts"]},
-                response_str=data["output"],
-                data_source={"openai/gsm8k": 1.0}
+            data_item = RLDataFlowItem(
+                uid = RLUIDItem(
+                    action_id=uuid4().int,
+                    observation_id=uuid4().int
+                ),
+                data = RLDatasetItem(
+                    messages=[{
+                        'role': 'user',
+                        'content': data["input"][5:-11]
+                    }],
+                    reward_model={"ground_truth": data["gts"]},
+                    data_source={"openai/gsm8k": 1.0}
+                ),
+                env = RLEnvDataItem(
+                    rollout=RLRolloutResponseItem(response=data['output'])
+                )
             )
             dataitem.append(data_item)
     return dataitem
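On the review question above about building an RLDataFlowItem from one big dict: a possible shape for such a helper is sketched below. This is purely illustrative and not part of the PR; the field names are taken from the diff, the default values are assumptions, and if these classes are pydantic models a built-in validator such as model_validate may already cover this case.

```python
# Hypothetical helper, not part of this PR: build the nested RLDataFlowItem
# structure from a plain dict so tests don't have to construct each sub-object.
from uuid import uuid4
from xtuner.v1.data_proto.rl_data import (
    RLDataFlowItem, RLDatasetItem, RLEnvDataItem, RLRolloutResponseItem, RLUIDItem,
)


def dataflow_item_from_dict(raw: dict) -> RLDataFlowItem:
    """Field names follow this PR's diff; defaults here are illustrative only."""
    return RLDataFlowItem(
        uid=RLUIDItem(action_id=uuid4().int, observation_id=uuid4().int),
        data=RLDatasetItem(
            messages=raw["messages"],
            num_tokens=raw.get("num_tokens", 0),
            reward_model=raw.get("reward_model", {}),
            ability=raw.get("ability", ""),
            data_source=raw.get("data_source", {}),
        ),
        env=RLEnvDataItem(
            rollout=RLRolloutResponseItem(response=raw.get("response_str", "")),
        ),
    )
```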
@@ -74,43 +88,44 @@ def tearDownClass(cls):

     def test_gsm8k_judger(self):
         from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
-        gsm8k_judger_config = GSM8KJudgerConfig()
+        gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
         judger_cfg = JudgerConfig(
-            reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+            reward_judger_configs=[gsm8k_judger_config]
         )
-        judger_controller = JudgerController.remote(judger_cfg)
-        res1 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM))
-        self.assertEqual(res1["reward"], 1.0)
+        judger_controller = JudgerController.remote(judger_cfg)
+        # the result has the form: RLJudgerResponseItem(uid=112750990920317762694895938380669501546, reward={'openai/gsm8k': 1}, extra_info={})
+        res1 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM))
+        self.assertEqual(res1.reward["openai/gsm8k"], 1.0)
         res2 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_DATA))
-        self.assertEqual(res2[0]["reward"], 1.0)
-        self.assertEqual(res2[1]["reward"], 1.0)
+        self.assertEqual(res2[0].reward["openai/gsm8k"], 1.0)
+        self.assertEqual(res2[1].reward["openai/gsm8k"], 1.0)

     def test_gsm8k_multi_judger(self):
         from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
-        gsm8k_judger_config_1 = GSM8KJudgerConfig()
-        gsm8k_judger_config_2 = GSM8KJudgerConfig()
+        # a single GSM8KJudgerConfig class can be used to create multiple judger instances
+        gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1")
+        gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2")
         judger_cfg = JudgerConfig(
-            reward_judger_configs={
-                "openai/gsm8k-1": gsm8k_judger_config_1,
-                "openai/gsm8k-2": gsm8k_judger_config_2,}
+            reward_judger_configs=[
+                gsm8k_judger_config_1,
+                gsm8k_judger_config_2
+            ]
         )
         judger_controller = JudgerController.remote(judger_cfg)
         res3 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE))
-        self.assertEqual(res3["reward"], 1.0)
+        self.assertEqual(res3.reward["weighted_reward"], 1.0)  # weighted_reward is a fixed field holding the weighted reward

     def test_gsm8k_judger_score(self):
         """Test the judger functionality with single and multiple data sources."""
         from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
-        gsm8k_judger_config = GSM8KJudgerConfig()
+        gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
         judger_cfg = JudgerConfig(
-            reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+            reward_judger_configs=[gsm8k_judger_config]
         )
         judger_controller = JudgerController.remote(judger_cfg)
         judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH)
         group_data = ray.get(judger_controller.run.remote(judger_data))
-        reward = []
-        for data in group_data:
-            reward.append(data["reward"])
+        reward = [data.reward["weighted_reward"] for data in group_data]
         avg_score = np.mean(reward)
         verl_score = 0.2418
         self.assertLessEqual(float(np.abs(avg_score - verl_score)), 0.001)

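The weighted_reward field asserted above combines the per-source rewards with the weights carried in data_source (here {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5}). A minimal sketch of that combination, under the assumption that the controller performs a plain weighted sum (the actual implementation is not shown in this diff):

```python
# Assumed behaviour only: fold per-judger rewards into "weighted_reward"
# using the weights from the item's data_source mapping.
def combine_rewards(per_judger_reward: dict, data_source: dict) -> dict:
    weighted = sum(data_source[name] * per_judger_reward[name] for name in data_source)
    return {**per_judger_reward, "weighted_reward": weighted}


rewards = combine_rewards(
    {"openai/gsm8k-1": 1.0, "openai/gsm8k-2": 1.0},
    {"openai/gsm8k-1": 0.5, "openai/gsm8k-2": 0.5},
)
print(rewards["weighted_reward"])  # 1.0, matching the assertion in test_gsm8k_multi_judger
```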
@@ -121,15 +136,14 @@ def test_gsm8k_remote_judger(self):
         server = JudgerServer(port=8018)
         server.start()

-        remote_judger_config = GSM8KRemoteJudgerConfig(remote_url=server.url)
+        remote_judger_config = GSM8KRemoteJudgerConfig(judger_name="openai/gsm8k", remote_url=server.url)
         judger_cfg = JudgerConfig(
-            reward_judger_configs={"openai/gsm8k": remote_judger_config}
+            reward_judger_configs=[remote_judger_config]
         )
         judger_controller = JudgerController.remote(judger_cfg)
         judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH)
         group_data = ray.get(judger_controller.run.remote(judger_data))
-
-        reward = [data["reward"] for data in group_data]
+        reward = [data.reward["reward"] for data in group_data]
         avg_score = np.mean(reward)
         verl_score = 0.2418
         self.assertLessEqual(float(np.abs(avg_score - verl_score)), 0.001)
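With this refactor, the judger name moves from the dict key of reward_judger_configs into each config's judger_name field, and JudgerConfig now takes a plain list. One plausible way a controller could rebuild a name-to-config lookup from that list is sketched below; this illustrates the data-structure change only and is not the actual JudgerController code.

```python
# Illustrative only: recover a name -> config mapping from the new list-based field.
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig

configs = [
    GSM8KJudgerConfig(judger_name="openai/gsm8k-1"),
    GSM8KJudgerConfig(judger_name="openai/gsm8k-2"),
]

# Equivalent of the old dict-style reward_judger_configs, keyed by judger_name.
judger_by_name = {cfg.judger_name: cfg for cfg in configs}
assert set(judger_by_name) == {"openai/gsm8k-1", "openai/gsm8k-2"}
```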
Review comment: the judger needs to support returning a dict or a dataclass, not just a single score.
Review comment: more generally, every return value is best made a dataclass or a dict, so it is easier to extend later.
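The diff already moves in this direction for the judger: per the comment in test_gsm8k_judger, JudgerController.run now returns an RLJudgerResponseItem carrying a reward dict plus an extra_info dict. A minimal sketch of that kind of structured return value, with field names taken from that comment (the real class definition is not shown in this diff):

```python
# Sketch of a structured judger return value; field names follow the
# RLJudgerResponseItem example quoted in the tests, details are assumed.
from dataclasses import dataclass, field


@dataclass
class JudgerResponse:
    uid: int
    reward: dict = field(default_factory=dict)       # per-judger scores plus "weighted_reward"
    extra_info: dict = field(default_factory=dict)   # room for future fields, per the review comments


resp = JudgerResponse(uid=1, reward={"openai/gsm8k": 1.0, "weighted_reward": 1.0})
assert resp.reward["weighted_reward"] == 1.0
```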