
Conversation

YanhuiDua
Collaborator

The main content of this PR is a refactor of the data structures in the dataflow, to achieve a clearer, more modular data flow. All related data structure definitions live in xtuner/v1/data_proto/rl_data.py.

RLDataFlowItem is the data structure that flows between the Dataflow and the Env; dataflow.run() ultimately returns RLDataFlowItem. It is composed of:

  • uid (RLUIDItem): a unique identifier, used to track the lifecycle and origin of the data.
  • data (RLDatasetItem): the raw data, typically the output of the dataset.
  • env (RLEnvDataItem): the collection of data produced by the Environment at each stage.
  • extra_info (RLExtraDataItem): a reserved field for extra information, making user customization and extension easier.

RLEnvDataItem collects and organizes the outputs of each stage inside the Environment; it contains:

  • rollout (RLRolloutResponseItem): output of the rollout stage, such as the model's generated text, token IDs, etc.
  • judger (RLJudgerResponseItem): output of the judger stage, such as reward scores and evaluation results.
  • agent (RLAgentDataItem): output of the agent stage (note: this structure is not yet fully defined).
  • extra_info (Dict): a reserved field for extra information.

See the code for the definitions of the other data structures.
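The nesting above can be sketched roughly as follows. The PR defines these as Pydantic models in xtuner/v1/data_proto/rl_data.py; this standalone sketch uses stdlib dataclasses instead, and any field not named in the description is an illustrative assumption, not the actual schema.

```python
# Sketch of the nested RLDataFlowItem structure described above.
# Dataclasses stand in for the PR's Pydantic models; exact field names
# (root_id, token_ids, ...) are assumptions for illustration.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class RLUIDItem:
    # Identifiers tracking the lifecycle and origin of the data (names assumed)
    root_id: Optional[str] = None
    action_id: Optional[str] = None
    observation_id: Optional[str] = None


@dataclass
class RLDatasetItem:
    # Raw sample, typically the dataset output
    messages: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class RLRolloutResponseItem:
    # Rollout-stage output: generated text, token IDs, finish reason, etc.
    response: Optional[str] = None
    token_ids: Optional[List[int]] = None
    finish_reason: Optional[str] = None


@dataclass
class RLJudgerResponseItem:
    # Judger-stage output: per-judger scores plus the final weighted reward
    reward: Dict[str, float] = field(default_factory=dict)


@dataclass
class RLEnvDataItem:
    rollout: RLRolloutResponseItem = field(default_factory=RLRolloutResponseItem)
    judger: RLJudgerResponseItem = field(default_factory=RLJudgerResponseItem)
    agent: Dict[str, Any] = field(default_factory=dict)  # not yet fully defined
    extra_info: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RLDataFlowItem:
    uid: RLUIDItem = field(default_factory=RLUIDItem)
    data: RLDatasetItem = field(default_factory=RLDatasetItem)
    env: RLEnvDataItem = field(default_factory=RLEnvDataItem)
    extra_info: Dict[str, Any] = field(default_factory=dict)


# Access pattern matching the reviewed code, e.g. item.env.rollout.finish_reason
item = RLDataFlowItem()
item.env.rollout.finish_reason = "stop"
```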

Other changes:

  • Initialization now requires a judger_name, used to distinguish the scores returned by different judgers; in addition, the judger computes a "weighted_reward" via weighted aggregation, which serves as the final reward score.
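A minimal sketch of the weighted-reward aggregation described above, keyed by judger_name; the function name, the example judger names, and the weights are hypothetical, not taken from the PR.

```python
# Sketch: combine per-judger scores (keyed by judger_name) into a single
# "weighted_reward" entry. Names and weights below are illustrative only.
from typing import Dict


def compute_weighted_reward(scores: Dict[str, float],
                            weights: Dict[str, float]) -> Dict[str, float]:
    """Return the per-judger scores plus their weighted sum as 'weighted_reward'."""
    weighted = sum(weights.get(name, 0.0) * score for name, score in scores.items())
    return {**scores, "weighted_reward": weighted}


rewards = compute_weighted_reward(
    {"math_judger": 1.0, "format_judger": 0.5},
    {"math_judger": 0.8, "format_judger": 0.2},
)
# rewards["weighted_reward"] == 0.8 * 1.0 + 0.2 * 0.5 == 0.9
```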

@YanhuiDua YanhuiDua requested review from hhaAndroid and Copilot and removed request for hhaAndroid September 18, 2025 14:50
Contributor

@Copilot Copilot AI left a comment


Pull Request Overview

This PR refactors the data structure in the RL dataflow to implement a clearer, more modular data flow architecture. All data structure definitions are consolidated in the file xtuner/v1/data_proto/rl_data.py. The main change introduces RLDataFlowItem as the central data structure that flows between Dataflow and Environment components, replacing the previous RLTextDataItem dictionary-based approach.

Key changes include:

  • Unified data structure: Replaces dictionary-based data handling with structured Pydantic models for better type safety and clarity
  • Judger name requirement: Adds judger_name parameter to distinguish different judger types and calculate weighted rewards
  • Enhanced rollout data: Adds support for token IDs, logprobs, and other rollout metadata in the response structure

Reviewed Changes

Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.

Summary per file:

  • xtuner/v1/data_proto/rl_data.py — Defines all new RL data structures including RLDataFlowItem, RLUIDItem, RLDatasetItem, etc.
  • xtuner/v1/train/rl_trainer.py — Updates the trainer to use the new data structure fields for accessing messages, rewards, and responses
  • xtuner/v1/ray/judger/controller.py — Refactors the judger controller to work with the new data structures and support weighted rewards
  • xtuner/v1/ray/rollout/worker.py — Updates the rollout worker to return RLRolloutResponseItem with enhanced metadata
  • xtuner/v1/ray/environment/single_turn_env.py — Modifies the environment to use the new data flow structures and update mechanisms
  • xtuner/v1/ray/dataflow/replay_buffer.py — Significantly refactors the replay buffer to work with the new data structures


observation_versions.append(version)
group_states.append(item.env.rollout.finish_reason)

state_str = "paused" if "paused" in group_states else "returned"

Copilot AI Sep 18, 2025


The finish_reason check for 'paused' state may fail since typical finish reasons are 'stop', 'length', 'tool_call', etc. The condition 'paused' in group_states will likely never be true, causing all items to be marked as 'returned' regardless of their actual state.

Suggested change
state_str = "paused" if "paused" in group_states else "returned"
# If any finish_reason is not 'stop', consider it 'paused' (partial rollout), else 'returned'
state_str = "paused" if any(reason != "stop" for reason in group_states) else "returned"


last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory
last_token_id = chunk_data["choices"][0]["delta"].get("gen_tokens")
if last_token_id is not None:
self.logger.info(f"{chunk_data['choices'][0]['delta']}")

Copilot AI Sep 18, 2025


Logging at INFO level inside a tight loop during streaming response processing can significantly impact performance. Consider using DEBUG level or removing this log statement in production code.

Suggested change
self.logger.info(f"{chunk_data['choices'][0]['delta']}")
self.logger.debug(f"{chunk_data['choices'][0]['delta']}")



flat_results = list(flatten(results))
assert len(flat_results) == len(group_data_item) * len(active_judgers), (
f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"

Copilot AI Sep 18, 2025


The error message could be more helpful by including the actual values. Consider adding the specific counts: f'Expected {len(group_data_item)} * {len(active_judgers)} = {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}'

Suggested change
f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"
f"Expected {len(group_data_item)} * {len(active_judgers)} = {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"



def _default_postprocess(self, result: Any) -> Any:
assert len(data_item) == 1, "Default preprocess only supports single data item."
# todo: 支持api server来计算batch_reward

Copilot AI Sep 18, 2025


The comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# TODO: Support batch reward calculation via API server'

Suggested change
# todo: 支持api server来计算batch_reward
# TODO: Support batch reward calculation via API server


}

def _default_postprocess(self, result: Any) -> List[RLJudgerResponseItem]:
## 将结果包装成 RLJudgerResponseItem

Copilot AI Sep 18, 2025


The comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# Wrap results into RLJudgerResponseItem'

Suggested change
## 将结果包装成 RLJudgerResponseItem
## Wrap results into RLJudgerResponseItem


action_id = replay_meta.action_id
state_str = replay_meta.state

# 记录没完成rollut的group_id,用于下次续roll

Copilot AI Sep 18, 2025


There's a typo in the Chinese comment 'rollut' should be 'rollout'. Also consider translating to English: '# Record unfinished rollout group_id for continuation in next roll'

Suggested change
# 记录没完成rollut的group_id,用于下次续roll
# Record unfinished rollout group_id for continuation in next roll


Comment on lines 202 to 212
# 记录没完成rollut的group_id,用于下次续roll
if state_str == "paused":
self._paused.append(action_id)
elif state_str == "returned":
self._returned.append(action_id)

# grpo算法下,一个prompt是一个action-id,如果prompt发生了变化,那就是新的action_id
# 一个prompt的不同回答对应不同的observation_id
# 多轮的情况下:当prompt发生变化,则会有新的action_id,通过root_id标识数据的最初来源

# action相关

Copilot AI Sep 18, 2025


Comments are in Chinese while the rest of the codebase uses English. Consider translating these comments to English for consistency.

Suggested change
# 记录没完成rollut的group_id,用于下次续roll
if state_str == "paused":
self._paused.append(action_id)
elif state_str == "returned":
self._returned.append(action_id)
# grpo算法下,一个prompt是一个action-id,如果prompt发生了变化,那就是新的action_id
# 一个prompt的不同回答对应不同的observation_id
# 多轮的情况下:当prompt发生变化,则会有新的action_id,通过root_id标识数据的最初来源
# action相关
# Record group_ids of unfinished rollouts for continuation in the next roll
if state_str == "paused":
self._paused.append(action_id)
elif state_str == "returned":
self._returned.append(action_id)
# In the GRPO algorithm, each prompt corresponds to an action_id; if the prompt changes, a new action_id is generated
# Different responses to the same prompt correspond to different observation_ids
# In multi-turn scenarios: when the prompt changes, a new action_id is generated; root_id identifies the original source of the data
# Action related


Comment on lines +50 to +51
# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
# 每个模块返回独立的data item, 在env中进行更新

Copilot AI Sep 18, 2025


Comments are in Chinese while the rest of the codebase uses English. Consider translating to English: '# Transform input data in env to support rollout_controller as standalone rollout engine, decoupling modules' and '# Each module returns independent data item, updated in env'

Suggested change
# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
# 每个模块返回独立的data item, 在env中进行更新
# Transform input data in env to support rollout_controller as a standalone rollout engine, decoupling modules
# Each module returns an independent data item, which is updated in env


"""
group_samples = group_samples_for_retry
try:
# 该函数中所有的数据结构都是RLDataFlowItem

Copilot AI Sep 18, 2025


Comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# All data structures in this function are RLDataFlowItem'

Suggested change
# 该函数中所有的数据结构都是RLDataFlowItem
# All data structures in this function are RLDataFlowItem


if not input_list:
return group_data_item[0]
return group_data_item
# 这里认为所有数据的data_source是一样的

Copilot AI Sep 18, 2025


Comment is in Chinese while the rest of the codebase uses English. Consider translating to English: '# Assume all data have the same data_source'

Suggested change
# 这里认为所有数据的data_source是一样的
# Assume all data have the same data_source

