[refactor] refactor rl data structure in dataflow #1110
base: main
Conversation
Pull Request Overview
This PR refactors the data structures in the RL dataflow to implement a clearer, more modular data flow architecture. All data structure definitions are consolidated in xtuner/v1/data_proto/rl_data.py. The main change introduces RLDataFlowItem as the central data structure that flows between the Dataflow and Environment components, replacing the previous dictionary-based RLTextDataItem approach.
Key changes include:
- Unified data structure: Replaces dictionary-based data handling with structured Pydantic models for better type safety and clarity
- Judger name requirement: Adds a judger_name parameter to distinguish different judger types and calculate weighted rewards (see the sketch below)
- Enhanced rollout data: Adds support for token IDs, logprobs, and other rollout metadata in the response structure
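A minimal sketch of the weighted-reward idea, assuming each judger is identified by its judger_name and the controller holds a name-to-weight mapping (the helper and the concrete names and weights below are illustrative, not the PR's actual API):

from typing import Dict

def combine_rewards(rewards_by_judger: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weight each judger's score by its judger_name and sum the contributions."""
    return sum(weights.get(name, 0.0) * score for name, score in rewards_by_judger.items())

# Hypothetical example with two judgers.
rewards = {"openai/gsm8k": 1.0, "format_judger": 0.5}
weights = {"openai/gsm8k": 0.8, "format_judger": 0.2}
weighted_reward = combine_rewards(rewards, weights)  # 0.8 * 1.0 + 0.2 * 0.5 = 0.9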
Reviewed Changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.
File | Description
---|---
xtuner/v1/data_proto/rl_data.py | Defines all new RL data structures including RLDataFlowItem, RLUIDItem, RLDatasetItem, etc.
xtuner/v1/train/rl_trainer.py | Updates the trainer to use the new data structure fields for accessing messages, rewards, and responses
xtuner/v1/ray/judger/controller.py | Refactors the judger controller to work with the new data structures and support weighted rewards
xtuner/v1/ray/rollout/worker.py | Updates the rollout worker to return RLRolloutResponseItem with enhanced metadata
xtuner/v1/ray/environment/single_turn_env.py | Modifies the environment to use the new data flow structures and update mechanisms
xtuner/v1/ray/dataflow/replay_buffer.py | Significantly refactors the replay buffer to work with the new data structures
xtuner/v1/ray/rollout/worker.py
Outdated
last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory
last_token_id = chunk_data["choices"][0]["delta"].get("gen_tokens")
if last_token_id is not None:
    self.logger.info(f"{chunk_data['choices'][0]['delta']}")
flat_results = list(flatten(results))
assert len(flat_results) == len(group_data_item) * len(active_judgers), (
    f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"
xtuner/v1/ray/judger/native.py
Outdated
def _default_postprocess(self, result: Any) -> Any:
    assert len(data_item) == 1, "Default preprocess only supports single data item."
    # todo: 支持api server来计算batch_reward
The comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# TODO: Support batch reward calculation via API server'
Suggested change:
# TODO: Support batch reward calculation via API server
def _default_postprocess(self, result: Any) -> List[RLJudgerResponseItem]:
    ## 将结果包装成 RLJudgerResponseItem
The comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# Wrap results into RLJudgerResponseItem'
Suggested change:
## Wrap results into RLJudgerResponseItem
action_id = replay_meta.action_id
state_str = replay_meta.state

# 记录没完成rollut的group_id,用于下次续roll
There's a typo in the Chinese comment: 'rollut' should be 'rollout'. Also consider translating it to English: '# Record unfinished rollout group_id for continuation in next roll'
Suggested change:
# Record unfinished rollout group_id for continuation in next roll
# 记录没完成rollut的group_id,用于下次续roll
if state_str == "paused":
    self._paused.append(action_id)
elif state_str == "returned":
    self._returned.append(action_id)

# grpo算法下,一个prompt是一个action-id,如果prompt发生了变化,那就是新的action_id
# 一个prompt的不同回答对应不同的observation_id
# 多轮的情况下:当prompt发生变化,则会有新的action_id,通过root_id标识数据的最初来源

# action相关
Comments are in Chinese while the rest of the codebase uses English. Consider translating these comments to English for consistency.
Suggested change:
# Record group_ids of unfinished rollouts for continuation in the next roll
if state_str == "paused":
    self._paused.append(action_id)
elif state_str == "returned":
    self._returned.append(action_id)
# In the GRPO algorithm, each prompt corresponds to an action_id; if the prompt changes, a new action_id is generated
# Different responses to the same prompt correspond to different observation_ids
# In multi-turn scenarios: when the prompt changes, a new action_id is generated; root_id identifies the original source of the data
# Action related
# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
# 每个模块返回独立的data item, 在env中进行更新
Comments are in Chinese while the rest of the codebase uses English. Consider translating to English: '# Transform input data in env to support rollout_controller as standalone rollout engine, decoupling modules' and '# Each module returns independent data item, updated in env'
Suggested change:
# Transform input data in env to support rollout_controller as a standalone rollout engine, decoupling modules
# Each module returns an independent data item, which is updated in env
"""
group_samples = group_samples_for_retry
try:
    # 该函数中所有的数据结构都是RLDataFlowItem
Comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# All data structures in this function are RLDataFlowItem'
Suggested change:
# All data structures in this function are RLDataFlowItem
xtuner/v1/ray/judger/controller.py
Outdated
if not input_list:
    return group_data_item[0]
return group_data_item
# 这里认为所有数据的data_source是一样的
Comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# Assume all data have the same data_source'
Suggested change:
# Assume all data have the same data_source
assert len(sample_params["stops"]) == 1
last_trajectory += sample_params["stops"][0]
if len(last_token_ids) > 0:
    last_token_ids.append(sample_params["stop_token_ids"][0])
The stop token is appended manually here, which makes the logprobs and token IDs different lengths; the stop token's logprob is missing.
Doesn't lmdeploy return it? Do we still need to append it ourselves?
Ah right, then we should use the stop_token returned by lmdeploy directly; the handling is simpler. This PR will add support for using the stop_token returned by lmdeploy, thanks.
if logprobs_content is not None:
    for content_item in logprobs_content["content"]:
        last_logprobs.append(content_item["logprob"])
# todo(@duanyanhui): remove appending stop tokens manually after lmdeploy supports returning stop_token_ids.
Has this requirement been raised with the lmdeploy team? Is there an estimated timeline?
lmdeploy already supports this; it still needs testing to confirm.
rollout: RLRolloutResponseItem = RLRolloutResponseItem()
judger: RLJudgerResponseItem = RLJudgerResponseItem()
agent: RLAgentDataItem = RLAgentDataItem()
The nesting of this data structure is a bit deep; why split it up this way?
In RLDataFlowItem, data is the action and env is the observation, so rollout/judger/agent are grouped together as env; this feels clearer.
The nesting here is bound to the modules: each module has its own corresponding data item, which is easier to understand and keeps the data flow easy to distinguish.
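For illustration only, a hedged example of how the module-bound nesting reads when consuming a returned item (field names follow the PR description; the values are made up):

item = dataflow_results[0]                    # an RLDataFlowItem returned by dataflow.run()
prompt_messages = item.data.messages          # dataset output (the "action")
generated_text = item.env.rollout.response    # rollout module output
reward_dict = item.env.judger.reward          # judger module output, e.g. {"weighted_reward": 1.0}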
xtuner/v1/data_proto/rl_data.py
Outdated
"""

env: str = ""
root_id: int = 0  # designed for grpo
Why is this GRPO-specific?
Well, it isn't actually GRPO-specific; it is meant to identify where the data came from.
class RLTextTokenizeFn(CachableTokenizeFunction[RLTextDataItem]):
    def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int | None = None, *args, **kwargs):
class RLTextTokenizeFn(CachableTokenizeFunction[RLDatasetItem]):
    def __init__(self, tokenizer: "PreTrainedTokenizer", max_length: int | None = None, *args, **kwargs):
No need to quote "PreTrainedTokenizer" here.
)
self.dataloader = build_dataloader(
    dataloader_config=config.dataloader_cfg,
    dataloader_config=self.fake_dataloader_cfg,
This way num_workers can no longer be changed.
OK, let's define a default here then: users can pass one in if they need to, otherwise the default dataloader_config is used.
assert len(sample_params["stops"]) == 1
last_trajectory += sample_params["stops"][0]
if len(last_token_ids) > 0:
    last_token_ids.append(sample_params["stop_token_ids"][0])
Doesn't lmdeploy return it? Do we still need to append it ourselves?
    dict: A dictionary containing the accuracy score.
"""
return {"accuracy": sum(s["reward"] > 0 for s in samples) / len(samples)}
return {"accuracy": sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples) / len(samples)}
Different evaluation sets compute metrics differently; compute_metric can currently only be set once, which is not very flexible.
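As a sketch of one possible way to make this more flexible (the registry and function names below are hypothetical, not part of this PR): register one metric function per data_source and dispatch on it.

from typing import Callable, Dict

MetricFn = Callable[[list], Dict[str, float]]

# Hypothetical registry: each evaluation set brings its own metric function.
metric_registry: Dict[str, MetricFn] = {
    "openai/gsm8k": lambda samples: {
        "gsm8k_accuracy": sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples) / len(samples)
    },
}

def compute_metrics(samples_by_source: Dict[str, list]) -> Dict[str, float]:
    """Dispatch each evaluation set to its registered metric function."""
    results: Dict[str, float] = {}
    for source, samples in samples_by_source.items():
        results.update(metric_registry[source](samples))
    return results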
pending_tasks.add(retry_task)
else:
    self.logger.error(f"Max retry reached for {result['prompt_id']}. Not retrying.")
    self.logger.error(f"Max retry reached for {result.uid.action_id}. Not retrying.")
If we don't self.return_list.append(result) here, training crashes outright. For now I work around it by appending an empty response with reward 0; is there a better way?
If the failed requests here aren't put into return_list, won't the total number of returned samples shrink? In that case the failed samples wouldn't be counted when averaging the metric, and the accuracy would be inflated.
Samples whose requests failed arguably shouldn't be counted in the metric; simply treating their metric as negative isn't very reasonable either.
"""

model_config = ConfigDict(extra="forbid")
response: Optional[str] = None
The current version of the data format doesn't support multi-turn conversational RL.
The previous version hacked around this by passing messages through response_str; with Pydantic enforcing types, multi-turn training is no longer possible.
Do you have a concrete example? My understanding is that RLRolloutResponseItem only covers the response of a single rollout; for multi-turn calls you would convert this object into messages for training, right? Does it raise an error as it is now?
97c5108 to c40e5c5 (Compare)
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
gsm8k_judger_config = GSM8KJudgerConfig()
gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
The judger needs to support returning a dict or a dataclass, not just a single score.
More generally, any return value is best made a dataclass or dict, to make future extension easier.
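A minimal sketch of what a dict-style judger return could look like, using the RLJudgerResponseItem described in this PR (the key names below are illustrative assumptions, not the PR's fixed schema):

# Instead of returning a bare float score, the judger wraps everything in a reward dict,
# which leaves room for extra fields later without changing the interface.
judger_item = RLJudgerResponseItem(
    reward={
        "score": 1.0,            # raw correctness score from the GSM8K check
        "weighted_reward": 0.8,  # score after applying the judger_name weight
    },
)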
def test_lmdeploy_evaluator(self):
    def custom_compute_metric(samples):
        return {"custom_accuracy": sum(s["reward"] > 0 for s in samples) / len(samples)}
        return {"custom_accuracy": sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples) / len(samples)}
Once a dict is returned, this part becomes simple.
FAKE_JUDGER_INPUT_ITEM = copy.deepcopy(FAKE_INPUT_DATA_ITEM)
FAKE_JUDGER_INPUT_ITEM["response_str"] = "<think>\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n</think>\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. \n\n#### 72"
FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM] * 2
FAKE_JUDGER_INPUT_ITEM = RLDataFlowItem(
Does this support passing in one big dict that is automatically converted to RLDataFlowItem, without having to construct every object individually?
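For reference, Pydantic can already build the nested item from a single plain dict via model_validate, so a fixture would not have to construct every sub-object by hand (the field values below are illustrative):

raw = {
    "data": {"messages": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April..."}]},
    "env": {"rollout": {"response": "#### 72"}},
}
item = RLDataFlowItem.model_validate(raw)   # nested dicts are coerced into the sub-models
assert item.env.rollout.response == "#### 72"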
rollout: RLRolloutResponseItem = RLRolloutResponseItem()
judger: RLJudgerResponseItem = RLJudgerResponseItem()
agent: RLAgentDataItem = RLAgentDataItem()
The nesting here is bound to the modules: each module has its own corresponding data item, which is easier to understand and keeps the data flow easy to distinguish.
ability: Optional[str] = None
reward_model: Optional[Dict[str, Any]] = None
data_source: Optional[Dict[str, Any]] = None
extra_info: Dict[str, Any] = dict()
Tool-related fields go into extra_info, right?
"""

model_config = ConfigDict(extra="forbid")
response: Optional[str] = None
Do you have a concrete example? My understanding is that RLRolloutResponseItem only covers the response of a single rollout; for multi-turn calls you would convert this object into messages for training, right? Does it raise an error as it is now?
model_config = ConfigDict(extra="forbid")
uid: Optional[int] = None
reward: Dict[str, Any] = dict()
A few keys could be predefined here, for example 'score' for computing the reward and 'acc' for validation, while also supporting user-defined keys.
# ==============================================


class SampleParams(BaseModel):
I've found that different inference backends use different parameter names, or their default values mean different things; there needs to be a place to convert between them.
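A sketch of the kind of conversion layer being suggested here, renaming shared sampling fields to whatever a given backend expects (the mapping contents are made-up examples, not the real backend APIs):

# Hypothetical per-backend field-name mapping; real names would come from each backend's docs.
BACKEND_KEY_MAP = {
    "lmdeploy": {"max_tokens": "max_new_tokens"},
    "vllm": {},
}

def to_backend_kwargs(params: "SampleParams", backend: str) -> dict:
    """Rename the shared SampleParams fields for the chosen rollout backend."""
    mapping = BACKEND_KEY_MAP.get(backend, {})
    return {mapping.get(key, key): value for key, value in params.model_dump().items()}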
# Note: the data format for the API server case is not defined here, since the OpenAI server format is used directly
class RLRolloutRequestItem(BaseModel):
    messages: List[Dict[str, Any]]
    tools: List = Field(default_factory=list)
To avoid the message that lmdeploy prints, the default value could be set to None.
for item in grouped_dataitem:
    version = item.uid.version
    observation_ids.append(item.uid.observation_id)
    observation_refs.append(ray.put(item.env))
I think the overhead here is still quite large, and I don't fully understand:
- Why is a ReplayBuffer needed as an intermediary, instead of managing this directly inside the dataflow?
- If an extra ReplayBuffer object is needed, does it really have to be a new Ray remote actor? That introduces cross-process overhead from put in the flow and get elsewhere; with many turns and a lot of data it would be worth measuring this time.
The main content of this PR is refactoring the data structures in the dataflow to achieve a clearer, more modular data flow. All related data structure definitions live in xtuner/v1/data_proto/rl_data.py. RLDataFlowItem is the data structure that flows between Dataflow and Env, and dataflow.run() ultimately returns RLDataFlowItem. It is composed of:
- uid (RLUIDItem): unique identifier used to track the data's lifecycle and origin.
- data (RLDatasetItem): the raw data, usually the output of the dataset.
- env (RLEnvDataItem): the collection of data produced by the Environment at each stage.
- extra_info (RLExtraDataItem): a reserved field for extra information, for user customization and extension.

RLEnvDataItem collects and organizes the outputs of the individual stages inside the Environment. It contains:
- rollout (RLRolloutResponseItem): output of the rollout stage, such as the model's generated text, token IDs, etc.
- judger (RLJudgerResponseItem): output of the judger stage, such as reward scores and evaluation results.
- agent (RLAgentDataItem): output of the agent stage (note: this part of the structure is not yet fully defined).
- extra_info (Dict): a reserved field for extra information.

The other data structures can be found in the code definitions; a condensed sketch of the hierarchy follows below.
Other changes: