[refactor] refactor rl data structure in dataflow #1110
Pull Request Overview
This PR refactors the data structures in the RL dataflow to implement a clearer, more modular data flow architecture. All data structure definitions are consolidated in `xtuner/v1/data_proto/rl_data.py`. The main change introduces `RLDataFlowItem` as the central data structure that flows between the Dataflow and Environment components, replacing the previous dictionary-based `RLTextDataItem` approach.
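The shift from dictionary-based to structured access can be sketched as follows; the field names are illustrative assumptions, not the PR's exact definitions, and stdlib dataclasses stand in here for the Pydantic models:

```python
from dataclasses import dataclass, field

# Before: RLTextDataItem was a plain dict -- stringly-typed, typo-prone access
old_item = {"prompt": "hi", "response": "hello", "reward": 1.0}
reward = old_item["reward"]

# After: RLDataFlowItem-style structured, attribute-based access
@dataclass
class Judger:
    reward: float = 0.0

@dataclass
class Env:
    judger: Judger = field(default_factory=Judger)

@dataclass
class DataFlowItem:
    env: Env = field(default_factory=Env)

new_item = DataFlowItem()
new_item.env.judger.reward = 1.0  # type checkers and IDEs can verify this path
```

Attribute access lets static tooling catch misspelled field names that a dict lookup would only surface at runtime.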
Key changes include:
- Unified data structure: Replaces dictionary-based data handling with structured Pydantic models for better type safety and clarity
- Judger name requirement: Adds a `judger_name` parameter to distinguish different judger types and calculate weighted rewards
- Enhanced rollout data: Adds support for token IDs, logprobs, and other rollout metadata in the response structure
Reviewed Changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| `xtuner/v1/data_proto/rl_data.py` | Defines all new RL data structures, including `RLDataFlowItem`, `RLUIDItem`, `RLDatasetItem`, etc. |
| `xtuner/v1/train/rl_trainer.py` | Updates the trainer to use the new data structure fields for accessing messages, rewards, and responses |
| `xtuner/v1/ray/judger/controller.py` | Refactors the judger controller to work with the new data structures and support weighted rewards |
| `xtuner/v1/ray/rollout/worker.py` | Updates the rollout worker to return `RLRolloutResponseItem` with enhanced metadata |
| `xtuner/v1/ray/environment/single_turn_env.py` | Modifies the environment to use the new data flow structures and update mechanisms |
| `xtuner/v1/ray/dataflow/replay_buffer.py` | Significantly refactors the replay buffer to work with the new data structures |
```python
observation_versions.append(version)
group_states.append(item.env.rollout.finish_reason)

state_str = "paused" if "paused" in group_states else "returned"
```
The finish_reason check for 'paused' state may fail since typical finish reasons are 'stop', 'length', 'tool_call', etc. The condition 'paused' in group_states will likely never be true, causing all items to be marked as 'returned' regardless of their actual state.
```diff
-state_str = "paused" if "paused" in group_states else "returned"
+# If any finish_reason is not 'stop', consider it 'paused' (partial rollout), else 'returned'
+state_str = "paused" if any(reason != "stop" for reason in group_states) else "returned"
```
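The suggested condition can be exercised in isolation; this sketch assumes typical finish reasons such as `'stop'` and `'length'`:

```python
def classify_group(group_states: list[str]) -> str:
    """Mark the group 'paused' if any rollout did not finish with 'stop'."""
    return "paused" if any(reason != "stop" for reason in group_states) else "returned"

assert classify_group(["stop", "stop"]) == "returned"   # fully finished group
assert classify_group(["stop", "length"]) == "paused"   # a truncated rollout pauses the group
```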
xtuner/v1/ray/rollout/worker.py (Outdated)
```python
last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory
last_token_id = chunk_data["choices"][0]["delta"].get("gen_tokens")
if last_token_id is not None:
    self.logger.info(f"{chunk_data['choices'][0]['delta']}")
```
Logging at INFO level inside a tight loop during streaming response processing can significantly impact performance. Consider using DEBUG level or removing this log statement in production code.
```diff
-self.logger.info(f"{chunk_data['choices'][0]['delta']}")
+self.logger.debug(f"{chunk_data['choices'][0]['delta']}")
```
```python
flat_results = list(flatten(results))
assert len(flat_results) == len(group_data_item) * len(active_judgers), (
    f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"
```
The error message could be more helpful by including the actual values. Consider adding the specific counts: f'Expected {len(group_data_item)} * {len(active_judgers)} = {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}'
```diff
-f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"
+f"Expected {len(group_data_item)} * {len(active_judgers)} = {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"
```
xtuner/v1/ray/judger/native.py (Outdated)
```python
def _default_postprocess(self, result: Any) -> Any:
    assert len(data_item) == 1, "Default preprocess only supports single data item."
    # todo: 支持api server来计算batch_reward
```
The comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# TODO: Support batch reward calculation via API server'
```diff
-# todo: 支持api server来计算batch_reward
+# TODO: Support batch reward calculation via API server
```
```python
}

def _default_postprocess(self, result: Any) -> List[RLJudgerResponseItem]:
    ## 将结果包装成 RLJudgerResponseItem
```
The comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# Wrap results into RLJudgerResponseItem'
```diff
-## 将结果包装成 RLJudgerResponseItem
+## Wrap results into RLJudgerResponseItem
```
```python
action_id = replay_meta.action_id
state_str = replay_meta.state

# 记录没完成rollut的group_id,用于下次续roll
```
There's a typo in the Chinese comment 'rollut' should be 'rollout'. Also consider translating to English: '# Record unfinished rollout group_id for continuation in next roll'
```diff
-# 记录没完成rollut的group_id,用于下次续roll
+# Record unfinished rollout group_ids for continuation in the next roll
```
```python
# 记录没完成rollut的group_id,用于下次续roll
if state_str == "paused":
    self._paused.append(action_id)
elif state_str == "returned":
    self._returned.append(action_id)

# grpo算法下,一个prompt是一个action-id,如果prompt发生了变化,那就是新的action_id
# 一个prompt的不同回答对应不同的observation_id
# 多轮的情况下:当prompt发生变化,则会有新的action_id,通过root_id标识数据的最初来源

# action相关
```
Comments are in Chinese while the rest of the codebase uses English. Consider translating these comments to English for consistency.
```diff
-# 记录没完成rollut的group_id,用于下次续roll
+# Record group_ids of unfinished rollouts for continuation in the next roll
 if state_str == "paused":
     self._paused.append(action_id)
 elif state_str == "returned":
     self._returned.append(action_id)
-# grpo算法下,一个prompt是一个action-id,如果prompt发生了变化,那就是新的action_id
-# 一个prompt的不同回答对应不同的observation_id
-# 多轮的情况下:当prompt发生变化,则会有新的action_id,通过root_id标识数据的最初来源
-# action相关
+# In the GRPO algorithm, each prompt corresponds to an action_id; if the prompt changes, a new action_id is generated
+# Different responses to the same prompt correspond to different observation_ids
+# In multi-turn scenarios: when the prompt changes, a new action_id is generated; root_id identifies the original source of the data
+# Action related
```
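The id semantics described in the comments above can be illustrated with a small sketch; the `UID` class and `new_uid` helper are hypothetical, not the PR's implementation:

```python
import uuid
from dataclasses import dataclass
from typing import Optional

@dataclass
class UID:
    root_id: str         # identifies the original source of the data across turns
    action_id: str       # one per prompt; a changed prompt means a new action_id
    observation_id: str  # one per response to the same prompt

def new_uid(prompt: str, root: Optional[str] = None) -> UID:
    # Derive action_id deterministically from the prompt, so the same prompt
    # maps to the same action_id; each response gets a fresh observation_id.
    action = uuid.uuid5(uuid.NAMESPACE_OID, prompt).hex
    return UID(root_id=root or action, action_id=action,
               observation_id=uuid.uuid4().hex)

a = new_uid("What is 2+2?")
b = new_uid("What is 2+2?", root=a.root_id)  # another response to the same prompt
assert a.action_id == b.action_id            # same prompt, same action_id
assert a.observation_id != b.observation_id  # different responses differ
```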
```python
# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
# 每个模块返回独立的data item, 在env中进行更新
```
Comments are in Chinese while the rest of the codebase uses English. Consider translating to English: '# Transform input data in env to support rollout_controller as standalone rollout engine, decoupling modules' and '# Each module returns independent data item, updated in env'
```diff
-# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
-# 每个模块返回独立的data item, 在env中进行更新
+# Transform input data in env to support rollout_controller as a standalone rollout engine, decoupling modules
+# Each module returns an independent data item, which is updated in env
```
```python
"""
group_samples = group_samples_for_retry
try:
    # 该函数中所有的数据结构都是RLDataFlowItem
```
Comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# All data structures in this function are RLDataFlowItem'
```diff
-# 该函数中所有的数据结构都是RLDataFlowItem
+# All data structures in this function are RLDataFlowItem
```
xtuner/v1/ray/judger/controller.py (Outdated)
```python
if not input_list:
    return group_data_item[0]
return group_data_item
# 这里认为所有数据的data_source是一样的
```
Comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# Assume all data have the same data_source'
```diff
-# 这里认为所有数据的data_source是一样的
+# Assume all data have the same data_source
```
The main content of this PR is a refactor of the data structures in dataflow, to achieve a clearer and more modular data flow. All related data structure definitions live in `xtuner/v1/data_proto/rl_data.py`. `RLDataFlowItem` is the data structure that flows between the Dataflow and the Env; `dataflow.run()` ultimately returns `RLDataFlowItem`. It is composed as follows:

- `uid` (`RLUIDItem`): unique identifier, used to track the lifecycle and origin of the data.
- `data` (`RLDatasetItem`): the raw data, typically the output of the dataset.
- `env` (`RLEnvDataItem`): the collection of data produced by the Environment at each stage.
- `extra_info` (`RLExtraDataItem`): a reserved field for extra information, for user customization and extension.

`RLEnvDataItem` collects and organizes the outputs of the individual stages inside `Environment`. It contains:

- `rollout` (`RLRolloutResponseItem`): the output of the rollout stage, such as the model's generated text and token IDs.
- `judger` (`RLJudgerResponseItem`): the output of the judger stage, such as reward scores and evaluation results.
- `agent` (`RLAgentDataItem`): the output of the agent stage (note: this structure is not yet fully defined).
- `extra_info` (`Dict`): a reserved field for extra information.

The other data structures can be found in the code definitions.
Other changes: