diff --git a/ci/scripts/test_dataflow.py b/ci/scripts/test_dataflow.py
index d0a65fd78..897cd5d85 100644
--- a/ci/scripts/test_dataflow.py
+++ b/ci/scripts/test_dataflow.py
@@ -48,9 +48,9 @@ def main():
tensor_parallel_size=8,
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+ reward_judger_configs=[gsm8k_judger_config]
)
dataflow_cfg = DataFlowConfig(
diff --git a/ci/scripts/test_grpo_trainer.py b/ci/scripts/test_grpo_trainer.py
index 2414baa9f..c5a2554e1 100644
--- a/ci/scripts/test_grpo_trainer.py
+++ b/ci/scripts/test_grpo_trainer.py
@@ -84,6 +84,9 @@ def main(args):
gpus_per_node=args.gpus_per_node, # gpu: 8, npu: 16
dtype="bfloat16",
skip_load_weights=False,
+ extra_rollout_config={
+ "lmdeploy_log_level": "CRITICAL",
+ }
)
dataflow_config = DataFlowConfig(
env="test",
@@ -93,9 +96,9 @@ def main(args):
sample_params=SampleParams(max_tokens=args.max_response_length),
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+ reward_judger_configs=[gsm8k_judger_config]
)
train_dataset_cfg = [
{
diff --git a/tests/ray/test_evaluator.py b/tests/ray/test_evaluator.py
index 4bbf7ac05..a8ca63fd5 100644
--- a/tests/ray/test_evaluator.py
+++ b/tests/ray/test_evaluator.py
@@ -42,9 +42,9 @@ def init_config(self):
tensor_parallel_size=8,
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
self.judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+ reward_judger_configs=[gsm8k_judger_config]
)
self.eval_dataset_cfg = [
@@ -82,7 +82,7 @@ def tearDown(self):
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_evaluator(self):
def custom_compute_metric(samples):
- return {"custom_accuracy": sum(s["reward"] > 0 for s in samples) / len(samples)}
+ return {"custom_accuracy": sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples) / len(samples)}
evaluator_cfg = EvaluatorConfig(
dataset_cfg=self.eval_dataset_cfg,
@@ -93,7 +93,6 @@ def custom_compute_metric(samples):
)
evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
correctness = ray.get(evaluator.run.remote(sample_params=self.sample_params))
-
custom_evaluator_cfg = EvaluatorConfig(
dataset_cfg=self.eval_dataset_cfg,
tokenizer=self.tokenizer,
diff --git a/tests/ray/test_judger.py b/tests/ray/test_judger.py
index 255279810..b61a50b62 100644
--- a/tests/ray/test_judger.py
+++ b/tests/ray/test_judger.py
@@ -6,49 +6,63 @@
import ray
import unittest
import numpy as np
-
+from uuid import uuid4
from xtuner.v1.ray.environment import SingleTurnEnvironment
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v1.ray.judger.controller import JudgerController, JudgerConfig
-from xtuner.v1.datasets.data_item import RLTextDataItem
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem, RLEnvDataItem, RLRolloutResponseItem, RLUIDItem
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"]
-FAKE_INPUT_DATA_ITEM = {
- 'messages': [{
- 'role': 'user', 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"'
- }],
- 'num_tokens': 62,
- 'reward_model': {'ground_truth': '72', 'style': 'rule'},
- 'ability': 'math',
- 'data_source': {'openai/gsm8k': 1.0},
- 'extra_info': {'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72', 'index': 0, 'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'split': 'train', 'raw_prompt': '<|im_start|>user\nNatalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####".<|im_end|>\n<|im_start|>assistant\n'},
- 'env': 'test_env',
- 'group_id': 255971142656329732139546771377476227093,
- 'prompt_id': 22175756018538642401581407443664245296,
- 'retry_times': 0}
-
-FAKE_JUDGER_INPUT_ITEM = copy.deepcopy(FAKE_INPUT_DATA_ITEM)
-FAKE_JUDGER_INPUT_ITEM["response_str"] = "\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. \n\n#### 72"
-FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM] * 2
+FAKE_JUDGER_INPUT_ITEM = RLDataFlowItem(
+ uid = RLUIDItem(action_id=uuid4().int,
+ observation_id=uuid4().int),
+ data = RLDatasetItem(
+ messages=[{
+ 'role': 'user', 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"'
+ }],
+ num_tokens=62,
+ reward_model={'ground_truth': '72', 'style': 'rule'},
+ ability='math',
+ data_source={'openai/gsm8k': 1.0}
+ ),
+ env = RLEnvDataItem(
+ rollout=RLRolloutResponseItem(
+ response="\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. \n\n#### 72<|im_end|>",
+ )
+ )
+)
+FAKE_JUDGER_INPUT_ITEM_1 = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM)
+FAKE_JUDGER_INPUT_ITEM_1.uid.observation_id = uuid4().int
+FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM, FAKE_JUDGER_INPUT_ITEM_1]  # the new observation_id marks these as two distinct input items
FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM)
-FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE['data_source'] = {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5}
+FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE.data.data_source = {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5}
def construct_judger_data(data_path):
dataitem = []
with open(data_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
- # 去除行尾的空白字符并解析JSON
data = json.loads(line.strip())
- data_item = RLTextDataItem(
- messages=data['input'],
- reward_model={"ground_truth": data["gts"]},
- response_str=data["output"],
- data_source={"openai/gsm8k": 1.0}
+ data_item = RLDataFlowItem(
+ uid = RLUIDItem(
+ action_id=uuid4().int,
+ observation_id=uuid4().int
+ ),
+ data = RLDatasetItem(
+ messages=[{
+ 'role': 'user',
+ 'content': data["input"][5:-11]
+ }],
+ reward_model={"ground_truth": data["gts"]},
+ data_source={"openai/gsm8k": 1.0}
+ ),
+ env = RLEnvDataItem(
+ rollout=RLRolloutResponseItem(response=data['output'])
+ )
)
dataitem.append(data_item)
return dataitem
@@ -74,43 +88,44 @@ def tearDownClass(cls):
def test_gsm8k_judger(self):
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+ reward_judger_configs=[gsm8k_judger_config]
)
- judger_controller = JudgerController.remote(judger_cfg)
- res1 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM))
- self.assertEqual(res1["reward"], 1.0)
+ judger_controller = JudgerController.remote(judger_cfg)
+        # The result has the form: RLJudgerResponseItem(uid=112750990920317762694895938380669501546, reward={'openai/gsm8k': 1}, extra_info={})
+ res1 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM))
+ self.assertEqual(res1.reward["openai/gsm8k"], 1.0)
res2 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_DATA))
- self.assertEqual(res2[0]["reward"], 1.0)
- self.assertEqual(res2[1]["reward"], 1.0)
+ self.assertEqual(res2[0].reward["openai/gsm8k"], 1.0)
+ self.assertEqual(res2[1].reward["openai/gsm8k"], 1.0)
def test_gsm8k_multi_judger(self):
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config_1 = GSM8KJudgerConfig()
- gsm8k_judger_config_2 = GSM8KJudgerConfig()
+        # A single GSM8KJudgerConfig class can be instantiated multiple times with different judger names
+ gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1")
+ gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2")
judger_cfg = JudgerConfig(
- reward_judger_configs={
- "openai/gsm8k-1": gsm8k_judger_config_1,
- "openai/gsm8k-2": gsm8k_judger_config_2,}
+ reward_judger_configs=[
+ gsm8k_judger_config_1,
+ gsm8k_judger_config_2
+ ]
)
judger_controller = JudgerController.remote(judger_cfg)
res3 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE))
- self.assertEqual(res3["reward"], 1.0)
+        self.assertEqual(res3.reward["weighted_reward"], 1.0)  # "weighted_reward" is a fixed key holding the reward aggregated across data sources
def test_gsm8k_judger_score(self):
"""Test the judger functionality with single and multiple data sources."""
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+ reward_judger_configs=[gsm8k_judger_config]
)
judger_controller = JudgerController.remote(judger_cfg)
judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH)
group_data = ray.get(judger_controller.run.remote(judger_data))
- reward = []
- for data in group_data:
- reward.append(data["reward"])
+ reward = [data.reward["weighted_reward"] for data in group_data]
avg_score = np.mean(reward)
verl_score = 0.2418
self.assertLessEqual(float(np.abs(avg_score - verl_score)), 0.001)
@@ -121,15 +136,14 @@ def test_gsm8k_remote_judger(self):
server = JudgerServer(port=8018)
server.start()
- remote_judger_config = GSM8KRemoteJudgerConfig(remote_url=server.url)
+ remote_judger_config = GSM8KRemoteJudgerConfig(judger_name="openai/gsm8k", remote_url=server.url)
judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": remote_judger_config}
+ reward_judger_configs=[remote_judger_config]
)
judger_controller = JudgerController.remote(judger_cfg)
judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH)
group_data = ray.get(judger_controller.run.remote(judger_data))
-
- reward = [data["reward"] for data in group_data]
+ reward = [data.reward["reward"] for data in group_data]
avg_score = np.mean(reward)
verl_score = 0.2418
self.assertLessEqual(float(np.abs(avg_score - verl_score)), 0.001)
diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py
index d4fa1357f..0e93819e8 100644
--- a/tests/ray/test_rollout.py
+++ b/tests/ray/test_rollout.py
@@ -59,9 +59,9 @@ def init_config(self):
dtype="bfloat16",
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
self.judger_cfg = JudgerConfig(
- reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
+ reward_judger_configs=[gsm8k_judger_config]
)
self.dataflow_cfg = DataFlowConfig(
env="test",
@@ -109,8 +109,9 @@ def test_lmdeploy_generate(self):
)
sample_params = SampleParams(temperature=0.0)
rollout_controller = RolloutController.remote(self.rollout_cfg, rollout_workers_map) # type: ignore[attr-defined]
- res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
-
+ extra_params = {"logprobs": True, "top_logprobs": 1, "return_token_ids": True}
+ res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params, extra_params=extra_params))
+
api_url = "http://localhost:8000/v1/chat/completions"
headers = {"Content-Type": "application/json"}
payload = {
@@ -126,8 +127,8 @@ def test_lmdeploy_generate(self):
self.assertEqual(res1.finish_reason, "stop")
self.assertEqual(response_data["finish_reason"], "stop")
self.assertEqual(res1.response, response_data["response"], f"response from function: {res1.response} != response from api server: {response_data["response"]}")
- print("Response from function:", res1.response)
- print("Response from API:", response_data["response"])
+ print("Response from function:", res1)
+ print("Response from API:", response_data)
ray.get(rollout_controller.shutdown.remote(), timeout=300)
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
@@ -144,7 +145,7 @@ def test_lmdeploy_dataflow(self):
self.test_env
)
responses = ray.get(self.test_flow.run.remote(), timeout=300)
- finished_samples_count = sum(1 for data in responses for item in data if item.get("state") == "stop" or item.get("state") == "length")
+        finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason in ("stop", "length"))
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote(), timeout=300)
@@ -162,7 +163,7 @@ def test_lmdeploy_async_dataflow(self):
self.test_env
)
responses = ray.get(self.test_flow.run.remote(), timeout=300)
- finished_samples_count = sum(1 for data in responses for item in data if item.get("state") == "stop" or item.get("state") == "length")
+        finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason in ("stop", "length"))
self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size)
ray.get(self.test_env.shutdown.remote())
diff --git a/xtuner/tools/plt.py b/xtuner/tools/plt.py
new file mode 100644
index 000000000..ba2231590
--- /dev/null
+++ b/xtuner/tools/plt.py
@@ -0,0 +1,134 @@
+import os
+import re
+import argparse
+import json
+import matplotlib.pyplot as plt
+
+def extract_xtuner_rewards(folder_path):
+    from pathlib import Path
+
+    # os is already imported at module level; glob is not needed.
+    matching_files = list(Path(folder_path).glob('rollout_idx_*.jsonl'))
+    file_count = len(matching_files)
+    steps = []
+    rewards = []
+
+    for i in range(1, file_count + 1):
+        file_path = os.path.join(folder_path, f'rollout_idx_{i}_trajectory.jsonl')
+        print(file_path)
+        reward = calculate_average_reward(file_path)
+        steps.append(i)
+        rewards.append(reward)
+    return steps, rewards
+
+def calculate_average_reward(file_path):
+ all_rewards = []
+ line_count = 0
+ try:
+ with open(file_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line_count += 1
+ if not line.strip():
+ continue
+ try:
+ data = json.loads(line)
+ if 'reward' in data and isinstance(data['reward'], list):
+ all_rewards.extend(data['reward'])
+ except json.JSONDecodeError:
+                    print(f"Warning: line {line_count} is not valid JSON and was skipped.")
+
+    except FileNotFoundError:
+        # Tolerate missing trajectory files; the caller receives None for this step.
+        return None
+
+    total_reward_sum = sum(all_rewards)
+    total_reward_count = len(all_rewards)
+    # Guard against files with no reward entries to avoid a ZeroDivisionError.
+    average_reward = total_reward_sum / total_reward_count if total_reward_count else 0.0
+    return average_reward
+
+def extract_accuracy(folder_path):
+ infer_log_path = folder_path + "/rank0.log"
+ extract_accuracy_list = []
+ accuracy_step = []
+ with open(infer_log_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ match = re.search(r"'accuracy':\s*([\d.]+)", line)
+ if match:
+ accuracy_value = float(match.group(1))
+ step_match = re.search(r"idx (.*?) scores", line)
+ idx_value = step_match.group(1).strip() if step_match else None
+ extract_accuracy_list.append(accuracy_value)
+ accuracy_step.append(idx_value)
+ return accuracy_step[1:], extract_accuracy_list[1:]
+
+def extract_entropy(folder_path):
+ train_log_path = folder_path + "/train_rank0.log"
+ entropy_list = []
+ with open(train_log_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ if "entropy" in line:
+ entropy = float(line.split(":")[4].strip())
+ entropy_list.append(entropy)
+ return entropy_list
+
+def extract_grad_norm(folder_path):
+ train_log_path = folder_path + "/train_rank0.log"
+ grad_norm_list = []
+ with open(train_log_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ if "grad_norm" in line:
+ print(line)
+ grad_norm = float(line.split("=")[4].strip())
+ grad_norm_list.append(grad_norm)
+ return grad_norm_list
+
+
+def plt_image(step, reward, entropy, grad_norm, loss, eval_step, eval_accuracy, output_path):
+ plt.figure(figsize=(14, 8))
+
+ # Plot each metric
+ if reward:
+ plt.plot(step, reward, marker='o', linestyle='-', markersize=4, label='Reward')
+ if grad_norm:
+ plt.plot(step, grad_norm, marker='x', linestyle=':', markersize=4, label='Gradient Norm')
+ if loss:
+ plt.plot(step, loss, marker='d', linestyle='-', markersize=4, label='Loss')
+ if entropy:
+ plt.plot(step, entropy, marker='s', linestyle='--', markersize=4, label='Entropy')
+ # if eval_accuracy:
+ # plt.scatter(eval_step, eval_accuracy, marker='o', color='r', label='EvalAccuracy', zorder=5)
+ # Add chart title and axis labels
+ plt.title('Training Metrics Over Steps', fontsize=16)
+ plt.xlabel('Step', fontsize=12)
+ plt.ylabel('Value', fontsize=12)
+
+ # Add a grid for better readability
+ plt.grid(True, which='both', linestyle='--', linewidth=0.5)
+
+ # Add a legend
+ plt.legend()
+
+ # Adjust layout automatically
+ plt.tight_layout()
+
+ # Save the chart to a file
+ plt.savefig(output_path)
+ print(f"Chart saved to: {output_path}")
+
+ # Display the chart
+ plt.show()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Plot rewards from XTuner log files.")
+ parser.add_argument("--log-dir-path", type=str, help="Path to the directory containing xtuner log files.")
+ args = parser.parse_args()
+
+ xtuner_steps, xtuner_reward = extract_xtuner_rewards(args.log_dir_path)
+ eval_step, eval_accuracy = extract_accuracy(args.log_dir_path)
+ entropy = extract_entropy(args.log_dir_path)
+ grad_norm = extract_grad_norm(args.log_dir_path)
+ save_path = os.path.join(args.log_dir_path, "xtuner_rl_metrics.png")
+    # Avoid shadowing the built-in len(); truncate all series to the shortest length.
+    num_steps = min(len(xtuner_steps), len(xtuner_reward), len(entropy), len(grad_norm))
+    # plt_image(xtuner_steps[:num_steps], xtuner_reward[:num_steps], entropy[:num_steps], grad_norm[:num_steps], None, eval_step, eval_accuracy, save_path)
+    plt_image(xtuner_steps[:num_steps], xtuner_reward[:num_steps], None, None, None, None, None, save_path)
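+
+# Usage sketch (the log directory layout is assumed from the extract_* helpers above):
+#   python xtuner/tools/plt.py --log-dir-path /path/to/work_dir/logs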
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
new file mode 100644
index 000000000..ed00cff93
--- /dev/null
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -0,0 +1,300 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import ray
+from cyclopts import Parameter
+from pydantic import BaseModel, ConfigDict, Field
+from ray import ObjectRef
+from typing_extensions import Annotated
+
+
+# ====================================
+# ====== DataFlow data structures =====
+# ====================================
+
+
+class RLUIDItem(BaseModel):
+ """A unique identifier for tracking data items within the dataflow.
+
+ Attributes:
+ env (str): The environment name.
+ root_id (int): The root ID for grouping related data items.
+        action_id (int): The ID for a specific action (i.e., a prompt).
+        observation_id (int): The ID for a specific observation (i.e., a response).
+ version (int): The version number of the data item.
+ """
+
+ model_config = ConfigDict(extra="forbid")
+ env: str = ""
+ root_id: int = -1
+ action_id: int = -1
+ observation_id: int = -1
+ version: int = -1
+
+
+class RLDatasetItem(BaseModel):
+ """Represents the data structure output from the dataset.
+
+ Attributes:
+ messages (Optional[List[Dict[str, Any]]]): The message list for the prompt.
+ input_ids (Optional[List[int]]): The tokenized input IDs.
+ num_tokens (Optional[int]): The number of tokens in the input.
+ ability (Optional[str]): The ability or category of the data.
+ reward_model (Optional[Dict[str, Any]]): Data required by the reward model, like ground truth.
+ data_source (Optional[Dict[str, Any]]): The source of the data, used for weighting rewards.
+ extra_info (Dict[str, Any]): Additional user-defined information.
+ """
+
+ model_config = ConfigDict(extra="forbid")
+ messages: Optional[List[Dict[str, Any]]] = None
+ input_ids: Optional[List[int]] = None
+ num_tokens: Optional[int] = None
+ ability: Optional[str] = None
+ reward_model: Optional[Dict[str, Any]] = None
+ data_source: Optional[Dict[str, Any]] = None
+ extra_info: Dict[str, Any] = dict()
+
+
+class RLRolloutResponseItem(BaseModel):
+ """Represents the data structure output from the rollout process.
+
+ Attributes:
+ response (Optional[str]): The generated text response from the model.
+ response_ids (Optional[List[int]]): The token IDs of the generated response.
+ num_return_tokens (Optional[int]): The number of tokens in the response.
+ finish_reason (Optional[str]): The reason why the generation finished (e.g., 'stop', 'length').
+ logprobs (Optional[List[float]]): The log probabilities of the generated tokens.
+ extra_info (Dict[str, Any]): Additional user-defined information.
+ """
+
+ model_config = ConfigDict(extra="forbid")
+ response: Optional[str] = None
+ response_ids: Optional[List[int]] = None
+ num_return_tokens: Optional[int] = None
+ finish_reason: Optional[str] = None
+ logprobs: Optional[List[float]] = None
+ extra_info: Dict[str, Any] = dict()
+
+
+class RLJudgerResponseItem(BaseModel):
+ """Represents the data structure output from the judger.
+
+ Attributes:
+ uid (Optional[int]): A unique ID to identify which input the result corresponds to.
+        reward (Dict[str, Any]): A dictionary of reward scores, e.g., {"openai/gsm8k": reward_score, "weighted_reward": weighted_score}.
+ extra_info (Dict[str, Any]): Additional user-defined information.
+ """
+
+ model_config = ConfigDict(extra="forbid")
+ uid: Optional[int] = None
+ reward: Dict[str, Any] = dict()
+ extra_info: Dict[str, Any] = dict()
+
+
+class RLAgentDataItem(BaseModel):
+ # todo: define agent output data structure
+ model_config = ConfigDict(extra="forbid")
+ extra_info: Dict[str, Any] = dict()
+
+
+class RLEnvDataItem(BaseModel):
+ """Contains the internal data structures of the environment, stored as an
+ observation.
+
+ Attributes:
+ rollout (RLRolloutResponseItem): Data from the rollout stage.
+ judger (RLJudgerResponseItem): Data from the judger stage.
+ agent (RLAgentDataItem): Data from the agent stage.
+ extra_info (Dict[str, Any]): Additional user-defined information.
+ """
+
+ model_config = ConfigDict(extra="forbid")
+ rollout: RLRolloutResponseItem = RLRolloutResponseItem()
+ judger: RLJudgerResponseItem = RLJudgerResponseItem()
+ agent: RLAgentDataItem = RLAgentDataItem()
+ extra_info: Dict[str, Any] = dict()
+
+
+class RLExtraDataItem(BaseModel):
+ """Reserved for data that does not belong to a specific stage of the
+ dataflow.
+
+ Attributes:
+ retry_times (int): The number of times the data processing has been retried.
+ extra_info (Dict[str, Any]): Additional user-defined information.
+ """
+
+ model_config = ConfigDict(extra="forbid")
+ retry_times: int = 0
+ extra_info: Dict[str, Any] = dict()
+
+
+class RLDataFlowItem(BaseModel):
+ """The core data structure that flows through the dataflow and environment.
+
+ It encapsulates all information related to a single data point, including its
+ unique ID, the original data, environment outputs, and extra metadata.
+
+ Attributes:
+ uid (RLUIDItem): The unique identifier for the data item.
+ data (RLDatasetItem): The original data from the dataset.
+ env (RLEnvDataItem): The collected outputs from the environment stages.
+ extra_info (RLExtraDataItem): Additional reserved information.
+ """
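+
+    Example (a minimal sketch; the field values below are illustrative, not produced
+    by a real rollout):
+
+        >>> item = RLDataFlowItem()
+        >>> item.data.messages = [{"role": "user", "content": "1 + 1 = ?"}]
+        >>> item.env.rollout.response = "#### 2"
+        >>> item.env.judger.reward = {"weighted_reward": 1.0}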
+
+ model_config = ConfigDict(extra="forbid")
+ uid: RLUIDItem = RLUIDItem()
+ data: RLDatasetItem = RLDatasetItem()
+ env: RLEnvDataItem = RLEnvDataItem()
+ extra_info: RLExtraDataItem = RLExtraDataItem()
+
+
+def update_dataflow_item(group_data_items, target_key, target_value):
+ """Update a list of RLDataFlowItem objects by setting a nested attribute
+ for each item.
+
+ Args:
+ group_data_items (List[RLDataFlowItem]): List of data items to update.
+ target_key (str): Dot-separated path to the attribute to update (e.g., 'env.rollout.response').
+ target_value (List[Any]): List of values to set, one for each data item.
+
+ Returns:
+ List[RLDataFlowItem]: The updated list of data items.
+
+ Example:
+ >>> # Suppose you want to update the 'response' field in env.rollout for each item
+ >>> items = [RLDataFlowItem(), RLDataFlowItem()]
+ >>> responses = ["hello", "world"]
+ >>> update_dataflow_item(items, "env.rollout.response", responses)
+ # Now items[0].env.rollout.response == "hello", items[1].env.rollout.response == "world"
+ """
+ group_length = len(group_data_items)
+ assert group_length == len(target_value)
+
+ keys = target_key.split(".")
+ for i in range(group_length):
+ parent_obj = group_data_items[i]
+ for key in keys[:-1]:
+ parent_obj = getattr(parent_obj, key)
+ setattr(parent_obj, keys[-1], target_value[i])
+
+ return group_data_items
+
+
+# ==============================================
+# ====== Rollout API server data structures ====
+# ==============================================
+
+
+class SampleParams(BaseModel):
+ n: Annotated[int, Parameter(help="Number of samples to generate.")] = 1
+ top_k: Annotated[
+ int, Parameter(help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
+ ] = 0
+ top_p: Annotated[float, Parameter(help="The cumulative probability for nucleus sampling.")] = 1.0
+ temperature: Annotated[float, Parameter(help="The value used to module the next token probabilities.")] = 1.0
+ repetition_penalty: Annotated[float, Parameter(help="The parameter for repetition penalty.")] = 1.0
+ presence_penalty: Annotated[float, Parameter(help="The parameter for presence penalty.")] = 0.0
+ frequency_penalty: Annotated[float, Parameter(help="The parameter for frequency penalty.")] = 0.0
+ min_tokens: Annotated[int, Parameter(help="Minimum number of tokens to generate.")] = 0
+ max_tokens: Annotated[int, Parameter(help="Maximum number of tokens to generate.")] = 2048
+ stops: Annotated[List[str], Parameter(help="List of stop sequences.")] = []
+ stop_token_ids: Annotated[List[int], Parameter(help="List of stop token IDs.")] = []
+ skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
+ do_sample: Annotated[bool, Parameter(help="Whether to sample or not.")] = True
+
+
+# Note: no dedicated data format is defined for the API server case because the OpenAI server format is used directly.
+class RLRolloutRequestItem(BaseModel):
+ messages: List[Dict[str, Any]]
+ tools: List = Field(default_factory=list)
+ tool_choice: str = "auto"
+ sample_params: SampleParams = Field(default_factory=SampleParams)
+ extra_params: Dict[str, Any] = Field(default_factory=dict)
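+
+    # Example request (illustrative only; the schema mirrors an OpenAI-style chat
+    # completion request, and the extra_params key is the one used in the rollout tests):
+    #   RLRolloutRequestItem(
+    #       messages=[{"role": "user", "content": "What is 1 + 1?"}],
+    #       sample_params=SampleParams(temperature=0.0, max_tokens=64),
+    #       extra_params={"return_token_ids": True},
+    #   )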
+
+
+# ==============================================
+# ====== ReplayBuffer data structures ==========
+# ==============================================
+
+
+@dataclass
+class ReplayMeta:
+ """ReplayMeta aggregates all versions of data related to a single prompt in
+ the replay buffer.
+
+ Attributes:
+ env (str): Name or identifier of the environment.
+ root_id (int): Identifier for grouping related prompts (e.g., for GRPO or multi-turn scenarios).
+ action_id (int): Unique identifier for the prompt. If the prompt changes (such as in a multi-turn scenario), a new action_id is assigned.
+ action_ref (ObjectRef): Ray object reference to the prompt data (corresponds to RLDatasetItem in RLDataFlowItem).
+ observation_ids (List[int]): IDs for different responses to the same prompt. Each response has a unique observation_id.
+ observation_refs (List[ObjectRef]): Ray object references to environment data for each observation (corresponds to RLEnvDataItem in RLDataFlowItem).
+ observation_versions (List[int]): Version numbers for each observation, supporting async rollout.
+ state (str): Overall state of the prompt (e.g., "paused" for partial rollout, or other rollout states).
+ extra_info (Dict[str, Any]): Additional metadata or information.
+ """
+
+ env: str = ""
+ root_id: int = 0
+ action_id: int = 0 # same prompt share the same action_id
+    action_ref: Optional[ObjectRef] = None
+ observation_ids: List[int] = field(default_factory=list) # observation IDs for different versions
+ observation_refs: List[ObjectRef] = field(default_factory=list)
+ observation_versions: List[int] = field(default_factory=list) # reserved for async rollout
+ state: str = "" # overall state, e.g., for partial rollout
+ extra_info: Dict[str, Any] = field(default_factory=dict)
+
+
+def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> ReplayMeta:
+ assert len(grouped_dataitem) > 0
+
+ env_str = grouped_dataitem[0].uid.env
+ root_id = grouped_dataitem[0].uid.root_id
+ action_id = grouped_dataitem[0].uid.action_id
+ data = grouped_dataitem[0].data
+ observation_ids = []
+ observation_refs = []
+ observation_versions = []
+
+ group_states = []
+ for item in grouped_dataitem:
+ version = item.uid.version
+ observation_ids.append(item.uid.observation_id)
+ observation_refs.append(ray.put(item.env))
+ observation_versions.append(version)
+ group_states.append(item.env.rollout.finish_reason)
+
+ state_str = "paused" if "paused" in group_states else "returned"
+ replay_meta = ReplayMeta(
+ env=env_str,
+ root_id=root_id,
+ action_id=action_id,
+ action_ref=ray.put(data),
+ observation_ids=observation_ids,
+ observation_refs=observation_refs,
+ observation_versions=observation_versions,
+        state=state_str,  # overall state of the prompt, used for partial rollout
+ extra_info={},
+ )
+ return replay_meta
+
+
+def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowItem]:
+ env_str = replay_meta.env
+ root_id = replay_meta.root_id
+ action_id = replay_meta.action_id
+ data_ref = ray.get(replay_meta.action_ref)
+ group_data_item = []
+ for obs_id, obs_ref, version in zip(
+ replay_meta.observation_ids, replay_meta.observation_refs, replay_meta.observation_versions
+ ):
+ env_data = ray.get(obs_ref)
+ item = RLDataFlowItem(
+ uid=RLUIDItem(env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, version=version),
+ data=data_ref,
+ env=env_data,
+ extra_info=RLExtraDataItem(),
+ )
+ group_data_item.append(item)
+ return group_data_item
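+
+
+# Round-trip sketch for the two mapping helpers above (assumes an initialized Ray
+# session; the uid values are illustrative):
+#
+#   items = [
+#       RLDataFlowItem(uid=RLUIDItem(root_id=1, action_id=2, observation_id=i))
+#       for i in range(2)
+#   ]
+#   meta = mapping_dataitem_to_replaymeta(items)      # group of items -> ReplayMeta
+#   restored = mapping_replaymeta_to_dataitem(meta)   # ReplayMeta -> group of items
+#   assert [it.uid.observation_id for it in restored] == [0, 1]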
diff --git a/xtuner/v1/datasets/data_item.py b/xtuner/v1/datasets/data_item.py
index 611e565ac..d621cec10 100644
--- a/xtuner/v1/datasets/data_item.py
+++ b/xtuner/v1/datasets/data_item.py
@@ -1,5 +1,3 @@
-from typing import Any, Dict, List
-
import torch
from typing_extensions import TypedDict
@@ -28,22 +26,3 @@ class QwenVL3DataItem(BaseMLLMDataItem, total=False):
pixel_values: torch.Tensor
image_grid_thw: torch.Tensor
position_ids: torch.Tensor
-
-
-class RLTextDataItem(CacheItem, total=False):
- env: str
- group_id: int
- prompt_id: int
- input_ids: list[int]
- messages: str | List[Dict[str, Any]]
- prompt: str
- data_source: dict | None # e.g., {"math" : "0.8", "code": "0.2"}
- ability: str | None # math, code
- reward_model: dict
- reward: float | None
- num_return_tokens: int | None
- response_ids: list[int] | None
- response_str: str | None
- state: str
- retry_times: int
- extra_info: dict
diff --git a/xtuner/v1/datasets/rl_tokenize_fn/rl_text_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/rl_text_fn.py
index 58379ce68..053a1ca7f 100644
--- a/xtuner/v1/datasets/rl_tokenize_fn/rl_text_fn.py
+++ b/xtuner/v1/datasets/rl_tokenize_fn/rl_text_fn.py
@@ -3,7 +3,8 @@
from pydantic import BaseModel, ConfigDict
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
-from xtuner.v1.datasets.data_item import CacheItem, RLTextDataItem
+from xtuner.v1.data_proto.rl_data import RLDatasetItem
+from xtuner.v1.datasets.data_item import CacheItem
from xtuner.v1.utils import get_logger
from ..utils import CachableTokenizeFunction
@@ -13,7 +14,7 @@
# https://github.com/volcengine/verl/blob/main/verl/utils/dataset/rl_dataset.py
-class RLTextTokenizeFn(CachableTokenizeFunction[RLTextDataItem]):
+class RLTextTokenizeFn(CachableTokenizeFunction[RLDatasetItem]):
def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = (
@@ -23,7 +24,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int | None = None
)
self.max_length = max_length
- def __call__(self, item: dict, **kwargs) -> RLTextDataItem | CacheItem:
+ def __call__(self, item: dict, **kwargs) -> RLDatasetItem | CacheItem:
"""example:
item = {
"data_source": data_source,
@@ -56,7 +57,6 @@ def __call__(self, item: dict, **kwargs) -> RLTextDataItem | CacheItem:
extra_info["raw_prompt"] = raw_prompt
rl_out_data = {
- # "input_ids": input_ids,
"messages": messages,
"num_tokens": num_tokens,
"reward_model": item["reward_model"],
diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py
index 715ece248..7fb326edf 100644
--- a/xtuner/v1/ray/config/worker.py
+++ b/xtuner/v1/ray/config/worker.py
@@ -107,6 +107,13 @@ class RolloutConfig(BaseModel):
help="Whether to enable cross-node communication for the rollout worker.",
),
] = False
+ rollout_max_batch_size: Annotated[
+ Optional[int],
+ Parameter(
+ group=infer_group,
+ help="Maximum batch size for the rollout worker. If not set, it will be determined automatically based on the model and GPU memory.",
+ ),
+ ] = None
tensor_parallel_size: Annotated[
int,
Parameter(
@@ -171,6 +178,13 @@ class RolloutConfig(BaseModel):
help="System prompt for the rollout worker.",
),
] = None
+ return_stop_tokens: Annotated[
+ bool,
+ Parameter(
+ group=infer_group,
+ help="Whether to return stop tokens in the rollout response.",
+ ),
+ ] = True
if __name__ == "__main__":
diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py
index b1a8d4da0..442a1ac47 100644
--- a/xtuner/v1/ray/dataflow/flow.py
+++ b/xtuner/v1/ray/dataflow/flow.py
@@ -7,7 +7,7 @@
from tqdm.auto import tqdm
from typing_extensions import Annotated
-from xtuner.v1.datasets.data_item import RLTextDataItem
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem
from xtuner.v1.ray.environment import SingleTurnEnvironment
from xtuner.v1.ray.rollout.controller import SampleParams
from xtuner.v1.ray.utils import create_task
@@ -112,7 +112,7 @@ def get_train_dataset_length(self):
"""Gets the length of the training dataset from the replay buffer."""
return ray.get(self.replay_buffer.get_train_dataset_length.remote())
- async def worker_task(self, group_samples_for_retry: Optional[List[RLTextDataItem]] = None):
+ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
"""A single worker task to generate and process a group of samples.
This task performs the following steps:
@@ -122,43 +122,36 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLTextDataIte
4. Adds the filtered samples to the replay buffer.
Args:
- group_samples_for_retry (Optional[List[RLTextDataItem]]): A group
+ group_samples_for_retry (Optional[List[RLDataFlowItem]]): A group
of samples to retry if a previous attempt failed. Defaults to
None.
Returns:
- Optional[List[RLTextDataItem]]: The group of samples if the task
+ Optional[List[RLDataFlowItem]]: The group of samples if the task
fails and needs to be retried, otherwise None.
"""
group_samples = group_samples_for_retry
try:
+            # All data passed through this function is typed as RLDataFlowItem.
# step 1: sample
- if group_samples is None:
- group_samples = await self.replay_buffer.sample.remote( # type: ignore[attr-defined]
- self.env,
- self.config.enable_partial_rollout,
- self.config.prompt_repeat_k,
- )
- self.send_samples_count += 1
- self.logger.debug(
- f"Get 1 sample and dataflow have sent {self.send_samples_count} to rollout_controller"
- )
- else:
- self.logger.debug("Retrying the failed sample")
+ group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined]
+ self.env,
+ self.config.enable_partial_rollout,
+ self.config.prompt_repeat_k,
+ )
+ self.send_samples_count += 1
+ self.logger.debug(f"Get 1 sample and dataflow have sent {self.send_samples_count} to rollout_controller")
# step 2: env generate
- group_samples = await self.env_controller.run.remote(group_samples, self.sample_params) # type: ignore[attr-defined]
+ group_data_items = await self.env_controller.run.remote(group_data_items, self.sample_params) # type: ignore[attr-defined]
# step 3: filter
- filtered_group_samples = await self.replay_buffer.post_processor.remote(group_samples) # type: ignore[attr-defined]
+ filtered_group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined]
# step 4: add to replay buffer
- await self.replay_buffer.add.remote(filtered_group_samples) # type: ignore[attr-defined]
-
+ await self.replay_buffer.add.remote(filtered_group_data_items) # type: ignore[attr-defined]
except Exception as e:
if group_samples is not None and len(group_samples) > 0:
self.logger.error(f"Worker task failed with exception: {e}. Returning meta for retry.", exc_info=True)
for sample in group_samples:
- if "retry_times" not in sample:
- sample["retry_times"] = 0
- sample["retry_times"] += 1
+ sample.extra_info.retry_times += 1
return group_samples
else:
self.logger.warning(f"Worker task failed with exception: {e}. No samples to return.")
@@ -247,7 +240,7 @@ async def run(
runner to collect a new batch of samples.
Returns:
- List[RLTextDataItem]: A list of collected training samples.
+ List[RLDataFlowItem]: A list of collected training samples.
"""
ray.get(self.env_controller.restart.remote()) # type: ignore[attr-defined]
self.send_samples_count = 0
diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py
index 13bf253ed..bcce954e6 100644
--- a/xtuner/v1/ray/dataflow/replay_buffer.py
+++ b/xtuner/v1/ray/dataflow/replay_buffer.py
@@ -1,21 +1,25 @@
-import copy
import itertools
-import uuid
from collections import defaultdict
-from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Union
from uuid import uuid4
import ray
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict
-from ray import ObjectRef
from typing_extensions import Annotated
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import (
+ ReplayMeta,
+ RLDataFlowItem,
+ RLDatasetItem,
+ RLExtraDataItem,
+ RLUIDItem,
+ mapping_dataitem_to_replaymeta,
+ mapping_replaymeta_to_dataitem,
+)
from xtuner.v1.datasets import build_dataloader, build_datasets
from xtuner.v1.datasets.config import DataloaderConfig
-from xtuner.v1.datasets.data_item import RLTextDataItem
from xtuner.v1.utils import get_logger
@@ -61,8 +65,8 @@ class ReplayBufferConfig(BaseModel):
dataset_cfg: Annotated[List, Parameter(help="The dataset object to sample initial prompts from.")]
dataloader_cfg: Annotated[
- DataloaderConfig, Parameter(help="The PyTorch DataLoader for iterating over the dataset.")
- ]
+ Optional[DataloaderConfig], Parameter(help="The PyTorch DataLoader for iterating over the dataset.")
+ ] = None
tokenizer: Annotated[
Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str],
@@ -82,23 +86,6 @@ class ReplayBufferConfig(BaseModel):
] = {}
-@dataclass
-class ReplayMeta:
- """A dataclass to store metadata for a single action-observation step in
- the replay buffer."""
-
- env: str = ""
- group_id: int = 0
- action_id: int = 0
- action_refs: List[ObjectRef] = field(default_factory=list) # multi-turn action
- observation_ids: List[int] = field(default_factory=list) # multi-turn action and partial rollout
- observation_refs: List[ObjectRef] = field(default_factory=list) # multi-turn action and partial rollout
- observation_versions: List[int] = field(default_factory=list) # partial rollout
- rewards: List[float] = field(default_factory=list)
- state: str = ""
- ground_truth: str = ""
-
-
class Sampler:
"""Sampler for drawing prompts from datasets or the replay buffer."""
@@ -121,7 +108,7 @@ def __init__(self, dataset, dataloader, tokenizer, storage):
)
self.storage = storage
- def sample_from_datasets(self, env: str, repeat_prompt_k: int) -> List[RLTextDataItem]:
+ def sample_from_datasets(self, env: str, repeat_prompt_k: int) -> List[RLDataFlowItem]:
"""Samples a new group of prompts from the original dataset.
Args:
@@ -129,42 +116,36 @@ def sample_from_datasets(self, env: str, repeat_prompt_k: int) -> List[RLTextDat
repeat_prompt_k (int): The number of times to repeat the prompt.
Returns:
- List[RLTextDataItem]: A list of data items for the data group contains repeat_prompt_k samples from same data.
+            List[RLDataFlowItem]: A data group of repeat_prompt_k items sampled from the same prompt.
"""
- group_id = uuid4().int
- group_samples: List[RLTextDataItem] = []
+ root_id = uuid4().int
+ action_id = uuid4().int
+ group_data_item: List[RLDataFlowItem] = [RLDataFlowItem() for _ in range(repeat_prompt_k)]
try:
data = next(self.train_dataloader_iter)[0]
except StopIteration:
self.train_dataloader_iter = iter(self.train_dataloader)
data = next(self.train_dataloader_iter)[0]
- for _ in range(repeat_prompt_k):
- prompt_id = uuid4().int
- data_item = copy.deepcopy(data)
- data_item["env"] = env
- data_item["group_id"] = group_id
- data_item["prompt_id"] = prompt_id
- data_item["retry_times"] = 0
- group_samples.append(data_item)
- return group_samples
- def sample_from_unfinished_buffer(self):
+ for data_item in group_data_item:
+ data_item.uid = RLUIDItem(
+ env=env,
+ root_id=root_id,
+ action_id=action_id,
+ observation_id=uuid4().int,
+ )
+ data_item.data = RLDatasetItem(**data)
+ data_item.extra_info = RLExtraDataItem(retry_times=0)
+ return group_data_item
+
+ def sample_from_unfinished_buffer(self) -> List[RLDataFlowItem]:
"""Samples a prompt from a partially completed (unfinished) rollout."""
- prompt_id = self.storage._rollout_states["unfinished"].pop(0)
- group_replay_meta = [self.storage._actions[action_id] for action_id in self.storage._prompt2actions[prompt_id]]
- group_samples = []
- for replay_meta in group_replay_meta:
- latest_prompt = ray.get(replay_meta.action_refs[-1])
- latest_observation = ray.get(replay_meta.observation_refs[-1]) if replay_meta.observation_refs else ""
- messages = [
- {"role": "user", "content": f"{latest_prompt}"},
- {"role": "assistant", "content": f"{latest_observation}"},
- ]
- data_item = self.storage.replaymeta2dataitem(replay_meta, messages=messages)
- group_samples.append(data_item)
+ action_id = self.storage._paused.pop(0)
+ replay_meta = self.storage._actions[action_id]
+ group_samples = mapping_replaymeta_to_dataitem(replay_meta)
return group_samples
- def sample(self, env: str, enable_partial_rollout: int, prompt_repeat_k: int) -> List[RLTextDataItem]:
+ def sample(self, env: str, enable_partial_rollout: int, prompt_repeat_k: int) -> List[RLDataFlowItem]:
"""Selects a sampling strategy and returns a group of samples.
It decides whether to sample from the unfinished buffer (for partial
@@ -176,20 +157,15 @@ def sample(self, env: str, enable_partial_rollout: int, prompt_repeat_k: int) ->
prompt_repeat_k (int): Number of times to repeat the prompt.
Returns:
- List[RLTextDataItem]: A list of sampled data items.
+ List[RLDataFlowItem]: A list of sampled data items.
"""
- if (
- enable_partial_rollout
- and "unfinished" in self.storage._rollout_states
- and len(self.storage._rollout_states["unfinished"]) > 0
- ):
+ if enable_partial_rollout > 0 and len(self.storage._paused) > 0:
return self.sample_from_unfinished_buffer()
else:
- # note: Sample grouped sample at once. They share the same prompt and
- # prompt id but different action_id.
+            # note: sample a whole group at once; all items share the same action_id
return self.sample_from_datasets(env, prompt_repeat_k)
- def resume(self, num):
+ def resume(self, num: int) -> None:
self.train_dataloader_iter = itertools.islice(self.train_dataloader, num, None)
@@ -198,110 +174,56 @@ class ReplayBufferStorage:
def __init__(self):
"""Initializes the data structures for storing replay data."""
- self._states: Dict[str, List[int]] = defaultdict(list) # str: [observation_id, observation_id, ...]
- self._rollout_states: Dict[str, List[int]] = defaultdict(
- list
- ) # str: [group_id, group_id, ...], designed for partial rollout
+        self._paused: List[int] = []  # list of paused action_ids
+        self._returned: List[int] = []  # list of returned action_ids
self._actions: Dict[int, ReplayMeta] = {} # action_id: ReplayMeta
+ self._root2actions: Dict[int, List[int]] = defaultdict(
+ list
+ ) # root_id: [action_id, action_id, ...], designed for grpo
self._observations: Dict[int, ReplayMeta] = {} # observation_id: ReplayMeta
- self._prompt2actions: Dict[int, List[int]] = defaultdict(list) # group_id: [action_id, action_id, ...]
+ self._observations2states: Dict[int, str] = {} # observation_id: state_str
+ self._states: Dict[str, List[int]] = defaultdict(list) # str: [observation_id, observation_id, ...]
self._action2observations: Dict[int, List[int]] = defaultdict(
list
) # action_id: [observation_id, observation_id, ...]
- self._observations2states: Dict[int, str] = {} # observation_id: state_str
self.logger = get_logger()
- def replaymeta2dataitem(self, replay_meta: ReplayMeta, messages=None, input_ids=None) -> RLTextDataItem:
- """Converts a ReplayMeta object to an RLTextDataItem.
-
- Args:
- replay_meta: The ReplayMeta object.
- prompt_str (Optional[str]): The prompt string. If None, it's
- retrieved from the replay_meta.
- input_ids (Optional[list]): The input IDs. Defaults to an empty list.
-
- Returns:
- RLTextDataItem: The converted data item.
- """
- messages = messages or (
- ray.get(replay_meta.action_refs[-1])
- if replay_meta.action_refs and len(replay_meta.action_refs) > 0
- else ""
- )
- input_ids = input_ids or []
- num_tokens = len(input_ids)
- response_str = (
- ray.get(replay_meta.observation_refs[0])
- if replay_meta.observation_refs and len(replay_meta.observation_refs) > 0
- else ""
- )
- return RLTextDataItem(
- env=replay_meta.env,
- group_id=replay_meta.group_id,
- prompt_id=replay_meta.action_id,
- messages=messages,
- input_ids=input_ids,
- num_tokens=num_tokens,
- response_str=response_str,
- reward_model={"ground_truth": replay_meta.ground_truth},
- reward=replay_meta.rewards[-1] if replay_meta.rewards and len(replay_meta.rewards) > 0 else None,
- state=replay_meta.state,
- )
-
- def dataitem2replaymeta(self, data_item: RLTextDataItem) -> ReplayMeta:
- """Converts an RLTextDataItem to a ReplayMeta object.
-
- Args:
- data_item: The RLTextDataItem to convert.
-
- Returns:
- ReplayMeta: The converted metadata object.
- """
- return ReplayMeta(
- env=data_item["env"],
- group_id=data_item["group_id"],
- action_id=data_item["prompt_id"],
- action_refs=[ray.put(data_item["messages"])] if "messages" in data_item else [],
- observation_ids=[uuid.uuid4().int],
- observation_refs=[ray.put(data_item["response_str"])] if "response_str" in data_item else [],
- observation_versions=[1],
- state=data_item["state"] if "state" in data_item else "",
- ground_truth=data_item["reward_model"]["ground_truth"],
- rewards=[data_item["reward"]] if "reward" in data_item and data_item["reward"] is not None else [],
- )
-
- def add(self, grouped_dataitem: List[RLTextDataItem]):
+ def add(self, grouped_dataitem: List[RLDataFlowItem]):
"""Adds a group of data items to the storage.
Args:
- grouped_dataitem (List[RLTextDataItem]): A list of data items
+ grouped_dataitem (List[RLDataFlowItem]): A list of data items
belonging to the same group.
"""
if len(grouped_dataitem) == 0:
return
- rollout_state = (
- "unfinished" if any(data_item["state"] == "unfinished" for data_item in grouped_dataitem) else "finished"
- )
- group_id = grouped_dataitem[0]["group_id"]
- self._rollout_states[rollout_state].append(group_id)
-
- for data_item in grouped_dataitem:
- replay_meta = self.dataitem2replaymeta(data_item)
- group_id = replay_meta.group_id
- action_id = replay_meta.action_id
- if action_id not in self._prompt2actions[group_id]:
- self._prompt2actions[group_id].append(action_id)
- self._actions[action_id] = replay_meta
- for observation_id in replay_meta.observation_ids:
- if observation_id not in self._states[replay_meta.state]:
- self._states[replay_meta.state].append(observation_id)
- if observation_id not in self._action2observations[action_id]:
- self._action2observations[action_id].append(observation_id)
- self._observations[observation_id] = replay_meta
- self._observations2states[observation_id] = replay_meta.state
-
- def get(self, global_batch_size: int) -> List[List[RLTextDataItem]]:
+ replay_meta = mapping_dataitem_to_replaymeta(grouped_dataitem)
+ root_id = replay_meta.root_id
+ action_id = replay_meta.action_id
+ state_str = replay_meta.state
+
+        # Partial rollout is tracked through the "paused" state: user-defined logic
+        # marks data whose inference was interrupted before completion as "paused",
+        # while every other state comes directly from the inference engine and is
+        # treated as "returned".
+ if state_str == "paused":
+ self._paused.append(action_id)
+ elif state_str == "returned":
+ self._returned.append(action_id)
+
+ # action
+ self._root2actions[root_id].append(action_id)
+ self._actions[action_id] = replay_meta
+
+ # observation
+ for observation_id in replay_meta.observation_ids:
+ self._action2observations[action_id].append(observation_id)
+ self._observations[observation_id] = replay_meta
+ self._observations2states[observation_id] = replay_meta.state
+ self._states[replay_meta.state].append(observation_id)
+
+ def get(self, global_batch_size: int) -> List[List[RLDataFlowItem]]:
"""Retrieves a batch of finished sample groups from the buffer.
Args:
@@ -312,34 +234,46 @@ def get(self, global_batch_size: int) -> List[List[RLTextDataItem]]:
to meet the `global_batch_size`.
Returns:
- List[List[RLTextDataItem]]: A list of sample groups. Each inner
+ List[List[RLDataFlowItem]]: A list of sample groups. Each inner
list contains a group of data items that were generated from the
same initial prompt, repeated `repeat_prompt_k` times.
"""
samples = []
- if len(self._rollout_states["finished"]) < global_batch_size:
+ if len(self._returned) < global_batch_size:
raise ValueError("Not enough finished samples in replay buffer")
else:
- target_finished_list = self._rollout_states["finished"][:global_batch_size]
- remain_finished_list = self._rollout_states["finished"][global_batch_size:]
- for group_id in target_finished_list:
- group_replay_meta = [self._actions[action_id] for action_id in self._prompt2actions[group_id]]
- group_samples = [self.replaymeta2dataitem(replay_meta) for replay_meta in group_replay_meta]
+ target_finished_list = self._returned[:global_batch_size]
+ remain_finished_list = self._returned[global_batch_size:]
+ for action_id in target_finished_list:
+ replay_meta = self._actions[action_id]
+                # TODO: add a unified state-management mechanism
+ replay_meta.state = "history"
+ group_samples = mapping_replaymeta_to_dataitem(self._actions[action_id])
samples.append(group_samples)
- self._rollout_states["finished"] = remain_finished_list
+ self._returned = remain_finished_list
return samples
+ def get_finished_samples(self):
+ """Returns the number of finished sample groups."""
+ return len(self._returned)
+
+ def get_unfinished_samples(self):
+ """Returns the number of unfinished sample groups."""
+ return len(self._paused)
+
+ def get_prompt_num(self):
+ return len(self._action2observations)
+
def print(self):
- finished_count = len(self._rollout_states["finished"])
- unfinished_count = len(self._rollout_states["unfinished"])
- group_count = len(self._prompt2actions)
+ rollout_finished_count = len(self._returned)
+ rollout_paused_count = len(self._paused)
action_count = len(self._actions)
observation_count = len(self._observations)
log_message = (
"ReplayBufferStorage states:\n"
- f" - Rollout States: Finished={finished_count}, Unfinished={unfinished_count}\n"
- f" - Sent Grouped Samples: {group_count}\n"
+ f" - Rollout States: Returned={rollout_finished_count}, Paused={rollout_paused_count}\n"
f" - History Actions: {action_count}\n"
f" - History Observations: {observation_count}"
)
@@ -359,40 +293,13 @@ def dump(self, file_path: str):
self.logger.info(f"Starting to dump ReplayBufferStorage state to {file_path}...")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
- # Deepcopy the state to avoid modifying the live buffer and to resolve
- # ObjectRefs in-place.
- actions_copy = copy.deepcopy(self._actions)
- # observations_copy = copy.deepcopy(self._observations)
- self.logger.info("Resolving ObjectRefs. This may take time and memory...")
-
- for replay_meta in actions_copy.values():
- if replay_meta.action_refs and all(isinstance(ref, ray.ObjectRef) for ref in replay_meta.action_refs):
- actions_list = []
- for ref in replay_meta.action_refs:
- actions_list.append(ray.get(ref))
- replay_meta.action_refs = actions_list
- if replay_meta.observation_refs and all(
- isinstance(ref, ray.ObjectRef) for ref in replay_meta.observation_refs
- ):
- observations_list = []
- for ref in replay_meta.observation_refs:
- observations_list.append(ray.get(ref))
- replay_meta.observation_refs = observations_list
-
- # Since _observations points to the same ReplayMeta objects as _actions,
- # we can reconstruct it on resume.
- state_to_dump = {
- "_states": self._states,
- "_rollout_states": self._rollout_states,
- "_actions": actions_copy,
- "_observations": self._observations,
- "_prompt2actions": self._prompt2actions,
- "_action2observations": self._action2observations,
- "_observations2states": self._observations2states,
- }
+ all_data_items = []
+ for replay_meta in self._actions.values():
+ group_data_items = mapping_replaymeta_to_dataitem(replay_meta)
+ all_data_items.append(group_data_items)
with open(file_path, "wb") as f:
- pickle.dump(state_to_dump, f)
+ pickle.dump(all_data_items, f)
self.logger.info(f"ReplayBufferStorage state dumped to {file_path}")
def resume(self, file_path: str):
@@ -411,45 +318,29 @@ def resume(self, file_path: str):
return
with open(file_path, "rb") as f:
- state_to_load = pickle.load(f)
-
- # Restore all components
- self._states = state_to_load["_states"]
- self._rollout_states = state_to_load["_rollout_states"]
- actions = state_to_load["_actions"]
- self._actions = {}
- self._observations = {}
- for action_id, replay_meta in actions.items():
- action_refs_list = []
- for action in replay_meta.action_refs:
- action_refs_list.append(ray.put(action))
- replay_meta.action_refs = action_refs_list
- observe_refs_list = []
- for observe in replay_meta.observation_refs:
- observe_refs_list.append(ray.put(observe))
- replay_meta.observation_refs = observe_refs_list
+ all_data_items = pickle.load(f)
+
+ for group_data_items in all_data_items:
+ replay_meta = mapping_dataitem_to_replaymeta(group_data_items)
+ root_id = replay_meta.root_id
+ action_id = replay_meta.action_id
+ state_str = replay_meta.state
+ if state_str == "paused":
+ self._paused.append(action_id)
+ elif state_str == "returned":
+ self._returned.append(action_id)
+ self._root2actions[root_id].append(action_id)
self._actions[action_id] = replay_meta
- for observ_id in replay_meta.observation_ids:
- self._observations[observ_id] = replay_meta
- self._prompt2actions = state_to_load["_prompt2actions"]
- self._action2observations = state_to_load["_action2observations"]
- self._observations2states = state_to_load["_observations2states"]
+ for observation_id in replay_meta.observation_ids:
+ self._action2observations[action_id].append(observation_id)
+ self._observations[observation_id] = replay_meta
+ self._observations2states[observation_id] = replay_meta.state
+ self._states[replay_meta.state].append(observation_id)
self.logger.info(f"ReplayBufferStorage state successfully resumed from {file_path}")
self.print()
- def get_finished_samples(self):
- """Returns the number of finished sample groups."""
- return len(self._rollout_states["finished"])
-
- def get_unfinished_samples(self):
- """Returns the number of unfinished sample groups."""
- return len(self._rollout_states["unfinished"])
-
- def get_prompt_num(self):
- return len(self._rollout_states)
-
@ray.remote
class ReplayBuffer:
@@ -469,8 +360,15 @@ def __init__(
self.tokenizer = config.tokenizer
self.datasets = build_datasets(config.dataset_cfg, self.tokenizer)
+ if config.dataloader_cfg is not None:
+ self.dataloader_cfg = config.dataloader_cfg
+ else:
+ self.dataloader_cfg = DataloaderConfig(
+ collator="fake_collator",
+ pack_level="none",
+ )
self.dataloader = build_dataloader(
- dataloader_config=config.dataloader_cfg,
+ dataloader_config=self.dataloader_cfg,
datasets=self.datasets,
global_batch_size=1,
micro_batch_size=1,
@@ -503,7 +401,7 @@ def post_processor(self, group_samples):
return group_samples
return group_samples
- def sample(self, env, enable_partial_rollout: int, prompt_repeat_k: int):
+ def sample(self, env, enable_partial_rollout: int, prompt_repeat_k: int) -> List[RLDataFlowItem]:
"""Samples a batch of experiences from the replay buffer.
Args:
@@ -530,11 +428,11 @@ def get_samples(
"""
return self.storage.get(global_batch_size)
- def add(self, grouped_dataitem: List[RLTextDataItem]):
+ def add(self, grouped_dataitem: List[RLDataFlowItem]):
"""Adds a group of data items to the replay buffer storage.
Args:
- grouped_dataitem (List[RLTextDataItem]): A list of data items
+ grouped_dataitem (List[RLDataFlowItem]): A list of data items
from the same group.
"""
self.storage.add(grouped_dataitem)
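
Note on the new checkpoint format: `dump` now pickles plain per-prompt groups of data items (converted via `mapping_replaymeta_to_dataitem`) rather than live state holding Ray ObjectRefs, and `resume` rebuilds the index structures from those groups. A minimal sketch of the round-trip under that assumption:

import pickle
from typing import List

def dump_groups(groups: List[list], file_path: str) -> None:
    # groups: List[List[RLDataFlowItem]] with no ray.ObjectRef handles inside,
    # so the checkpoint can be reloaded without the original Ray cluster.
    with open(file_path, "wb") as f:
        pickle.dump(groups, f)

def load_groups(file_path: str) -> List[list]:
    with open(file_path, "rb") as f:
        return pickle.load(f)  # resume() rebuilds its indices from these groups
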
diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py
index 61d968d87..e9274ca1e 100644
--- a/xtuner/v1/ray/environment/base_env.py
+++ b/xtuner/v1/ray/environment/base_env.py
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
-from typing import Any, List, Union
+from typing import Any, List
+
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem
class BaseEnvironment(ABC):
@@ -85,7 +87,7 @@ def init_judger_controller(self, placement_group: Any, judger_cfg: Any):
return judger_controller
@abstractmethod
- async def generate(self, data: Union[list, Any, List[Any]], sample_params: Any) -> Union[list, Any, List[Any]]:
+ async def generate(self, data: List[RLDataFlowItem], sample_params: Any) -> List[RLDataFlowItem]:
"""Generates responses from the model for the given data using the
inference engine. This method is primarily used for single-step
inference.
@@ -100,7 +102,7 @@ async def generate(self, data: Union[list, Any, List[Any]], sample_params: Any)
pass
@abstractmethod
- async def run(self, data: Union[list, Any, List[Any]], sample_params: Any) -> Union[list, Any, List[Any]]:
+ async def run(self, data: List[RLDataFlowItem], sample_params: Any) -> List[RLDataFlowItem]:
"""Executes a full cycle of generation and interpretation, such as
generating a response and then evaluating it with a judger. This method
can be extended to support complex interactions like multi-turn
diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py
index d5737abbf..adba5dba6 100644
--- a/xtuner/v1/ray/environment/single_turn_env.py
+++ b/xtuner/v1/ray/environment/single_turn_env.py
@@ -3,8 +3,9 @@
import ray
-from xtuner.v1.datasets.data_item import RLTextDataItem
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem, update_dataflow_item
from xtuner.v1.ray.environment.base_env import BaseEnvironment
+from xtuner.v1.utils import get_logger
@ray.remote
@@ -24,17 +25,18 @@ class SingleTurnEnvironment(BaseEnvironment):
def __init__(self, environment: str, placement_group, rollout_cfg=None, judger_cfg=None):
super().__init__(environment, placement_group, rollout_cfg, judger_cfg)
+ self.logger = get_logger(__name__)
- async def generate(self, group_samples: List[RLTextDataItem], sample_params: None) -> List[RLTextDataItem]:
+ async def generate(self, group_data_items: List[RLDataFlowItem], sample_params: None) -> List[RLDataFlowItem]:
"""Generate responses for a batch of RLTextDataItem using the rollout
controller.
- Each item in `group_samples` will be sent to the rollout controller for response generation
+ Each item in `group_data_items` will be sent to the rollout controller for response generation
with the provided sampling parameters. The generated response string and state will be
added to each RLTextDataItem in-place as `response_str` and `state` fields.
Args:
- group_samples (List[RLTextDataItem]):
+ group_data_items (List[RLDataFlowItem]):
A list of RLTextDataItem objects containing the prompts/messages for generation.
sample_params: Sampling parameters for the generation process. The type should match
the rollout controller's expected sampling parameter type (e.g., SampleParams or dict).
@@ -45,18 +47,17 @@ async def generate(self, group_samples: List[RLTextDataItem], sample_params: Non
and state from the rollout controller.
"""
if self.rollout_controller:
+ # The input data is converted inside the env so that the rollout_controller can also be used as a standalone rollout engine, keeping the modules decoupled.
+ # Each module returns its own data item; the env merges the updates.
response_future = [
- self.rollout_controller.rollout.remote(prompt=sample["messages"], sample_params=sample_params)
- for sample in group_samples
+ self.rollout_controller.rollout.remote(prompt=sample.data.messages, sample_params=sample_params)
+ for sample in group_data_items
]
- response = await asyncio.gather(*response_future)
- for i in range(len(group_samples)):
- group_samples[i]["response_str"] = response[i].response
- group_samples[i]["state"] = response[i].finish_reason
+ rollout_responses = await asyncio.gather(*response_future) # RLRolloutResponseItem
+ group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_responses)
+ return group_data_items
- return group_samples
-
- async def run(self, group_samples: List[RLTextDataItem], sample_params: None) -> List[RLTextDataItem]:
+ async def run(self, group_data_items: List[RLDataFlowItem], sample_params: None) -> List[RLDataFlowItem]:
"""Runs a full generation and judger cycle.
This method first generates responses using the `generate` method and then,
@@ -72,7 +73,8 @@ async def run(self, group_samples: List[RLTextDataItem], sample_params: None) ->
The data enriched with generated responses and evaluation results.
The format of the return value matches the format of the input `data`.
"""
- group_samples = await self.generate(group_samples, sample_params) # type: ignore[assignment]
+ group_data_items = await self.generate(group_data_items, sample_params) # type: ignore[assignment]
if self.judger_controller:
- group_samples = await self.judger_controller.run.remote(group_samples)
- return group_samples
+ judger_responses: List[RLJudgerResponseItem] = await self.judger_controller.run.remote(group_data_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_responses)
+ return group_data_items
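
The intent of the in-env conversion is that each module (rollout, judger) returns its own response item and the env merges them into the flow items via `update_dataflow_item`. A minimal driver-side sketch, assuming `env` is a `SingleTurnEnvironment` actor handle built as in the tests earlier in this patch and that `SampleParams` keeps its `max_tokens` field:

from xtuner.v1.data_proto.rl_data import RLDataFlowItem, SampleParams

async def roll_and_judge(env, items: list[RLDataFlowItem]) -> list[RLDataFlowItem]:
    # run() fills item.env.rollout (response, finish_reason, ...) and item.env.judger (reward dict).
    out = await env.run.remote(items, SampleParams(max_tokens=128))
    for item in out:
        print(item.env.rollout.response, item.env.judger.reward["weighted_reward"])
    return out
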
diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py
index 1130ef49d..46951b718 100644
--- a/xtuner/v1/ray/evaluator.py
+++ b/xtuner/v1/ray/evaluator.py
@@ -8,9 +8,9 @@
from typing_extensions import Annotated
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem
from xtuner.v1.datasets import build_datasets
from xtuner.v1.datasets.config import DatasetConfigList
-from xtuner.v1.datasets.data_item import RLTextDataItem
from xtuner.v1.ray.environment import BaseEnvironment
from xtuner.v1.ray.rollout import SampleParams
from xtuner.v1.ray.utils import create_task
@@ -116,7 +116,7 @@ def __init__(self, config: EvaluatorConfig, env_controller: BaseEnvironment):
)
self.dataloader = iter(self.dataset)
self.env_controller = env_controller
- self.return_list: List[RLTextDataItem] = []
+ self.return_list: List[RLDataFlowItem] = []
if self.config.eval_sample_ratio > 0:
self.eval_batch_size = int(len(self.dataset) * self.config.eval_sample_ratio)
elif self.config.eval_sample_num > 0:
@@ -135,14 +135,14 @@ def default_compute_metric(self, samples):
Calculates accuracy based on whether the reward is positive.
Args:
- samples (list): A list of RLTextDataItem samples.
+ samples (list): A list of RLDataFlowItem samples.
Returns:
dict: A dictionary containing the accuracy score.
"""
- return {"accuracy": sum(s["reward"] > 0 for s in samples) / len(samples)}
+ return {"accuracy": sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples) / len(samples)}
- async def eval_worker_task(self, sample: RLTextDataItem):
+ async def eval_worker_task(self, sample: RLDataFlowItem):
"""A single worker task to evaluate one sample.
This task calls the environment controller to run the model on a
@@ -150,21 +150,19 @@ async def eval_worker_task(self, sample: RLTextDataItem):
retry count.
Args:
- sample (RLTextDataItem): The data item to evaluate.
+ sample (RLDataFlowItem): The data item to evaluate.
Returns:
- RLTextDataItem or None: The sample with retry information if it
+ RLDataFlowItem or None: The sample with retry information if it
failed, or None if it succeeded or failed without a sample.
"""
try:
# note: In the evaluator, we convert the input sample to a list to adapt to the input format of single_turn_env
- samples = await self.env_controller.run.remote([sample], self.sample_params) # type: ignore[attr-defined]
- self.return_list.append(samples[0])
+ group_sample = await self.env_controller.run.remote([sample], self.sample_params) # type: ignore[attr-defined]
+ self.return_list.append(group_sample[0])
except Exception as e:
- self.logger.error(f"Worker task failed with exception: {e}. Returning meta for retry.", exc_info=True)
- if "retry_times" not in sample:
- sample["retry_times"] = 0
- sample["retry_times"] += 1
+ self.logger.error(f"Worker task failed with exception: {e}. Returning sample for retry.")
+ sample.extra_info.retry_times += 1
return sample
async def concurrent_eval_task_runner(self):
@@ -190,10 +188,11 @@ async def concurrent_eval_task_runner(self):
if len(self.return_list) + len(waiting_tasks) >= self.eval_batch_size:
break
try:
- sample = next(self.dataloader)
+ data = next(self.dataloader)
except StopIteration:
break
- task = create_task(self.eval_worker_task(sample))
+ data_item = RLDataFlowItem(data=RLDatasetItem(**data))
+ task = create_task(self.eval_worker_task(data_item))
waiting_tasks.add(task)
assert len(waiting_tasks) > 0
@@ -203,12 +202,13 @@ async def concurrent_eval_task_runner(self):
for task in done_tasks:
result = task.result()
if result is not None:
- if result["retry_times"] < self.config.max_retry_times:
+ if result.extra_info.retry_times < self.config.max_retry_times:
# If the retry count is less than max_retry_times, retry the task
retry_task = create_task(self.eval_worker_task(result))
pending_tasks.add(retry_task)
else:
- self.logger.error(f"Max retry reached for {result['prompt_id']}. Not retrying.")
+ self.logger.error(f"Max retry reached for {result.uid.action_id}. Not retrying.")
+
waiting_tasks = pending_tasks
pbar.n = len(self.return_list)
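
Since rewards now live under `env.judger.reward`, custom metrics read the weighted reward from the returned flow items. A minimal sketch of a metric function compatible with `compute_metric`, assuming `samples` is the list of `RLDataFlowItem` objects collected in `return_list` (the name `pass_rate` is illustrative):

def pass_rate(samples):
    # A sample counts as solved when its combined reward is positive,
    # mirroring default_compute_metric above.
    solved = sum(s.env.judger.reward["weighted_reward"] > 0 for s in samples)
    return {"pass_rate": solved / max(len(samples), 1)}
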
diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py
index 074e007cd..62ff32c71 100644
--- a/xtuner/v1/ray/judger/controller.py
+++ b/xtuner/v1/ray/judger/controller.py
@@ -1,12 +1,12 @@
import asyncio
-from typing import Dict, List
+from typing import List
import ray
from cyclopts import Parameter
from pydantic import BaseModel
from typing_extensions import Annotated
-from xtuner.v1.datasets.data_item import RLTextDataItem
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem
from xtuner.v1.utils import get_logger
from .native import NativeJudger
@@ -62,9 +62,9 @@ class JudgerConfig(BaseModel):
bool, Parameter(help="Whether to enable batch reward calculation for multiple samples at once.")
] = False
reward_judger_configs: Annotated[
- Dict[str, BaseModel],
+ List[BaseModel],
Parameter(help="A custom Python function for computing reward given model output and label."),
- ] = {}
+ ] = []
@ray.remote
@@ -83,17 +83,17 @@ def __init__(self, judger_config: JudgerConfig, placement_group=None):
# note: placement_group is used to control the placement of Ray tasks.
# It will be implemented when gpu judger is needed
self.placement_group = placement_group
- self.reward_judger = {}
- for name, config in self.judger_config.reward_judger_configs.items():
- self.reward_judger[name] = config.build()
+ self.reward_judger = []
+ for config in self.judger_config.reward_judger_configs:
+ self.reward_judger.append(config.build())
+
self.logger = get_logger()
async def _call_custom_reward_judger(
self,
- active_judgers: Dict[str, NativeJudger],
- responses: List[str],
- labels: List[str],
- ) -> Dict[str, List[float]]:
+ active_judgers: List[NativeJudger],
+ group_data_item: List[RLDataFlowItem],
+ ) -> List[RLJudgerResponseItem]:
"""Call custom reward judgers to calculate rewards.
Args:
@@ -106,36 +106,44 @@ async def _call_custom_reward_judger(
Dict[str, List[float]]: A dictionary where keys are judger names
and values are lists of calculated rewards for each sample.
"""
- group_size = len(responses)
+ results = []
if self.judger_config.enable_batch_reward:
- tasks = {name: judger.judge(responses, labels) for name, judger in active_judgers.items()}
- results = await asyncio.gather(*tasks.values())
- return dict(zip(tasks.keys(), [results] * group_size))
-
+ tasks = [judger.judge(group_data_item) for judger in active_judgers]
+ results = await asyncio.gather(*tasks)
else:
tasks_per_sample = [
- [(name, judger.judge(responses[i], labels[i])) for name, judger in active_judgers.items()]
- for i in range(len(responses))
+ [(judger.judge([group_data_item[i]])) for judger in active_judgers]
+ for i in range(len(group_data_item))
]
flat_tasks_with_names = [task for sample_tasks in tasks_per_sample for task in sample_tasks]
- coroutines = [item[1] for item in flat_tasks_with_names]
-
- flat_results = await asyncio.gather(*coroutines)
- final_rewards: Dict[str, List[float]] = {
- name: [] for name in active_judgers
- } # name: [sample1, sample2, ...]
- active_reward_size = len(active_judgers)
- for name_index in range(active_reward_size):
- reward_list = []
- for index in range(group_size):
- reward_list.append(flat_results[index * active_reward_size + name_index])
- final_rewards[list(active_judgers.keys())[name_index]] = reward_list
- return final_rewards
+ results = await asyncio.gather(*flat_tasks_with_names)
+
+ import collections.abc
+
+ def flatten(results):
+ for item in results:
+ if isinstance(item, collections.abc.Iterable) and not isinstance(item, (str, bytes, dict)):
+ yield from item
+ else:
+ yield item
+
+ flat_results = list(flatten(results))
+ assert len(flat_results) == len(group_data_item) * len(active_judgers), (
+ f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(flat_results)}"
+ )
+ # Assemble the RLJudgerResponseItems returned by different judgers into one response item per sample.
+ uid_list = [item.uid.observation_id for item in group_data_item]
+ judger_response_items_dict = {uid: RLJudgerResponseItem(uid=uid) for uid in uid_list}
+ for result in flat_results:
+ return_uid = result.uid
+ judger_response_items_dict[return_uid].reward.update(result.reward)
+ judger_response_items_dict[return_uid].extra_info.update(result.extra_info)
+ return list(judger_response_items_dict.values())
async def run(
- self, group_data_item: RLTextDataItem | List[RLTextDataItem]
- ) -> RLTextDataItem | List[RLTextDataItem]:
+ self, group_data_item: RLDataFlowItem | List[RLDataFlowItem]
+ ) -> RLJudgerResponseItem | List[RLJudgerResponseItem]:
"""Run the judging process for a group of data items.
Args:
@@ -145,32 +153,28 @@ async def run(
Returns:
List[float]: A list of final calculated rewards for each data item.
"""
- if not group_data_item:
- return []
- input_list = True
- if not isinstance(group_data_item, List):
+ input_type_is_list = True
+ if not isinstance(group_data_item, list):
+ input_type_is_list = False
group_data_item = [group_data_item]
- input_list = False
- batch_responses = [item["response_str"] or "" for item in group_data_item]
- batch_labels = [item["reward_model"]["ground_truth"] for item in group_data_item]
- data_source = group_data_item[0]["data_source"]
- assert data_source, "No data source found for the given datsetes"
-
- active_reward_judger = {name: func for name, func in self.reward_judger.items() if name in data_source}
- assert active_reward_judger, f"No active reward judger found for the given data source {data_source}."
-
- rewards_by_name = await self._call_custom_reward_judger(active_reward_judger, batch_responses, batch_labels)
- num_samples = len(group_data_item)
- final_rewards = [0.0] * num_samples
-
- for i in range(num_samples):
- for name, scores in rewards_by_name.items():
- weight = data_source.get(name, 1.0)
- final_rewards[i] += scores[i] * weight
-
- assert len(final_rewards) == num_samples
- for i, item in enumerate(group_data_item):
- item["reward"] = final_rewards[i]
- if not input_list:
- return group_data_item[0]
- return group_data_item
+ # Assume all data have the same data_source
+ data_source = group_data_item[0].data.data_source
+ assert data_source, "No data source found for the given datasets"
+
+ judger_names = [judger.judger_name for judger in self.reward_judger]
+ active_reward_judger = [func for func in self.reward_judger if func.judger_name in data_source]
+ assert active_reward_judger, (
+ f"No active reward judger in {judger_names} found for the given data source {data_source}."
+ )
+
+ judger_response_item = await self._call_custom_reward_judger(active_reward_judger, group_data_item)
+ for item in judger_response_item:
+ final_reward = 0
+ for name, weight in data_source.items():
+ if name in item.reward:
+ final_reward += item.reward[name] * weight
+ item.reward["weighted_reward"] = final_reward
+
+ if input_type_is_list is False:
+ return judger_response_item[0]
+ return judger_response_item
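
Wiring note: judgers are matched to samples by `judger_name` against the keys of the sample's `data.data_source` mapping, and per-judger rewards are folded into `reward["weighted_reward"]` using the mapping values as weights. A minimal configuration sketch, assuming the dataset annotates `data_source` as `{"openai/gsm8k": 1.0}`:

from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig

# judger_name must match a key of data.data_source; the mapping value is the
# weight applied when computing reward["weighted_reward"].
gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config])
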
diff --git a/xtuner/v1/ray/judger/gsm8k.py b/xtuner/v1/ray/judger/gsm8k.py
index 15599c3e5..94a5c54eb 100644
--- a/xtuner/v1/ray/judger/gsm8k.py
+++ b/xtuner/v1/ray/judger/gsm8k.py
@@ -81,6 +81,7 @@ def compute_reward(response, label, extra_info):
class GSM8KJudgerConfig(BaseModel):
"""Configuration for the GSM8K judger."""
+ judger_name: str = "gsm8k_judger"
extra_info: dict = {"score": 1, "format_score": 0}
def build(self):
@@ -89,4 +90,4 @@ def build(self):
Returns:
NativeJudger: An instance of the NativeJudger configured for GSM8K.
"""
- return NativeJudger(reward_func=compute_reward, extra_info=self.extra_info)
+ return NativeJudger(judger_name=self.judger_name, reward_func=compute_reward, extra_info=self.extra_info)
diff --git a/xtuner/v1/ray/judger/native.py b/xtuner/v1/ray/judger/native.py
index ca689ec76..a84779ff2 100644
--- a/xtuner/v1/ray/judger/native.py
+++ b/xtuner/v1/ray/judger/native.py
@@ -3,6 +3,9 @@
import httpx
+from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem
+from xtuner.v1.utils import get_logger
+
class NativeJudger:
"""Base class for judgers, providing a standard interface for executing a
@@ -16,6 +19,7 @@ class NativeJudger:
def __init__(
self,
+ judger_name: str = "native_judger",
reward_func: Optional[Callable] = None,
remote_url: Optional[str] = None,
preprocess_func: Optional[Callable] = None,
@@ -47,7 +51,7 @@ def __init__(
"""
if (reward_func is None and remote_url is None) or (reward_func is not None and remote_url is not None):
raise ValueError("Exactly one of 'reward_func' or 'remote_url' must be provided.")
-
+ self.judger_name = judger_name
self.extra_info = extra_info
self.reward_func = reward_func
self.remote_url = remote_url
@@ -63,20 +67,31 @@ def __init__(
elif self.remote_url:
self.http_client = httpx.AsyncClient(timeout=request_timeout)
self.execute_func = self._remote_executor
+ self.logger = get_logger(__name__)
- def _default_preprocess(self, responses: str | List[str], labels: str | List[str]) -> Any:
+ def _default_preprocess(self, data_item: List[RLDataFlowItem], extra_info: dict) -> Any:
"""Default preprocessing function.
Args:
- responses (str | List[str]): The model's response(s).
- labels (str | List[str]): The ground-truth label(s).
+ data_item (RLDataFlowItem | List[RLDataFlowItem]): The data item(s) to preprocess.
Returns:
Any: A dictionary containing the responses, labels, and extra info.
"""
- return {"response": responses, "label": labels, "extra_info": self.extra_info}
- def _default_postprocess(self, result: Any) -> Any:
+ assert len(data_item) == 1, "Default preprocess only supports single data item."
+ # TODO: Support batch reward calculation via API server
+ response = data_item[0].env.rollout.response
+ assert data_item[0].data.reward_model is not None
+ label = data_item[0].data.reward_model["ground_truth"]
+ return {
+ "response": response,
+ "label": label,
+ "extra_info": extra_info,
+ }
+
+ def _default_postprocess(self, result: Any) -> List[RLJudgerResponseItem]:
+ ## Wrap the raw result into RLJudgerResponseItem objects.
"""Default postprocessing function.
Args:
@@ -85,9 +100,12 @@ def _default_postprocess(self, result: Any) -> Any:
Returns:
Any: The result, unchanged.
"""
- return result
+ if not isinstance(result, list):
+ result = [result]
+ judger_response_item = [RLJudgerResponseItem(reward={self.judger_name: result[i]}) for i in range(len(result))]
+ return judger_response_item
- async def _local_executor(self, responses: str | List[str], labels: str | List[str]) -> Any:
+ async def _local_executor(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]:
"""Executes the reward function locally.
Args:
@@ -98,14 +116,20 @@ async def _local_executor(self, responses: str | List[str], labels: str | List[s
Any: The postprocessed result of the reward function.
"""
assert self.reward_func is not None, "reward_func cannot be None for local execution."
- kwargs = self.preprocess_func(responses, labels)
+ # Record the uid of each judger request so the results can be merged back later.
+ uid_list = [item.uid.observation_id for item in data_item]
+ kwargs = self.preprocess_func(data_item, self.extra_info)
if inspect.iscoroutinefunction(self.reward_func):
result = await self.reward_func(**kwargs)
else:
result = self.reward_func(**kwargs)
- return self.postprocess_func(result)
- async def _remote_executor(self, responses: str | List[str], labels: str | List[str]) -> Any:
+ result = self.postprocess_func(result)
+ for i in range(len(result)):
+ result[i].uid = uid_list[i]
+ return result
+
+ async def _remote_executor(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]:
"""Executes the reward function by calling a remote service.
Args:
@@ -119,17 +143,19 @@ async def _remote_executor(self, responses: str | List[str], labels: str | List[
assert self.remote_url is not None and self.http_client is not None, (
"remote_url cannot be None for remote execution."
)
- payload = self.preprocess_func(responses, labels)
+ payload = self.preprocess_func(data_item, self.extra_info)
try:
response = await self.http_client.post(self.remote_url, json=payload)
response.raise_for_status()
result = response.json()
+ # Important: attach the request uid so the remote result can be matched back to its sample.
+ result["uid"] = data_item[0].uid.observation_id
return self.postprocess_func(result)
except httpx.RequestError as exc:
- print(f"An error occurred while requesting {exc.request.url}: {exc}")
- return None
+ self.logger.error(f"An error occurred while requesting {exc.request.url}: {exc}")
+ return []
- async def judge(self, responses: str | List[str], labels: str | List[str]) -> Any:
+ async def judge(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]:
"""The main public method to run the judging pipeline.
Args:
@@ -145,4 +171,4 @@ async def judge(self, responses: str | List[str], labels: str | List[str]) -> An
"""
if self.execute_func is None:
raise RuntimeError("Judger is not properly initialized.")
- return await self.execute_func(responses, labels)
+ return await self.execute_func(data_item)
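
With the new interface, `judge` takes a list of `RLDataFlowItem` and returns `RLJudgerResponseItem` objects keyed by each item's observation uid. A minimal sketch of a custom local judger built on the default pre/postprocessing above (the reward function receives `response`, `label`, and `extra_info` kwargs); the function name and scoring rule are illustrative:

from xtuner.v1.ray.judger.native import NativeJudger

def exact_match_reward(response, label, extra_info):
    # Kwargs mirror _default_preprocess: rollout response, ground-truth label, extra_info dict.
    return extra_info.get("score", 1) if str(response).strip() == str(label).strip() else 0

judger = NativeJudger(
    judger_name="exact_match",
    reward_func=exact_match_reward,
    extra_info={"score": 1},
)
# await judger.judge([data_item]) yields [RLJudgerResponseItem] with
# reward == {"exact_match": 0 or 1} and uid set to data_item.uid.observation_id.
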
diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py
index df52aeab7..670b17c54 100644
--- a/xtuner/v1/ray/rollout/controller.py
+++ b/xtuner/v1/ray/rollout/controller.py
@@ -7,9 +7,10 @@
from fastapi import FastAPI
from transformers import AutoTokenizer
+from xtuner.v1.data_proto.rl_data import RLRolloutRequestItem, RLRolloutResponseItem, SampleParams
from xtuner.v1.ray.config.worker import RolloutConfig
-from .worker import RolloutRequest, RolloutResponse, RolloutWorker, SampleParams
+from .worker import RolloutWorker
@ray.remote
@@ -113,7 +114,8 @@ async def rollout(
sample_params: Optional[SampleParams] = None,
extra_params: dict = dict(),
format: str = "openai",
- ) -> RolloutResponse:
+ ) -> RLRolloutResponseItem:
+ # This function accepts the standard OpenAI chat-completions interface, so no extra input schema needs to be defined here.
"""Perform a rollout using one of the workers in a round-robin fashion.
Args:
@@ -154,7 +156,7 @@ def start_api_server(self, host: str = "0.0.0.0", port: int = 8000):
port = self.config.api_port if self.config.api_port else port
@app.post("/v1/chat/completions")
- async def chat_completions(request: RolloutRequest) -> RolloutResponse:
+ async def chat_completions(request: RLRolloutRequestItem) -> RLRolloutResponseItem:
response = await self.rollout(
prompt=request.messages,
tools=request.tools,
diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/ray/rollout/lmdeploy.py
index dbd4c4752..d817687f0 100644
--- a/xtuner/v1/ray/rollout/lmdeploy.py
+++ b/xtuner/v1/ray/rollout/lmdeploy.py
@@ -105,8 +105,19 @@ async def _create_request(
"tool_choice": tool_choice,
"stream": True,
}
+
+ # todo(@duanyanhui): it will be supported after lmdeploy supports stop tokens in release version.
+ # if self.config.return_stop_tokens:
+ # payload["include_stop_str_in_output"] = True
+
payload.update(sample_params)
payload.update(extra_params)
+
+ if "logprobs" in payload and payload["logprobs"]:
+ self.logger.warning(
+ "LMDeploy can't return stop tokens' logprobs now. It will be supported in next release version."
+ )
+
req = self.client.build_request(
"POST",
url,
@@ -215,16 +226,20 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
tp=tp_size,
ep=ep_size,
dp=dp_size,
+ max_batch_size=self.config.rollout_max_batch_size,
empty_init=self.config.skip_load_weights,
distributed_executor_backend=distributed_executor_backend,
mp_engine_backend="ray", # force ray to pass placement group
device_type=accelerator_to_device_type[self.accelerator],
+ logprobs_mode="raw_logprobs",
)
if backend == "pytorch"
else TurbomindEngineConfig(
tp=tp_size,
+ max_batch_size=self.config.rollout_max_batch_size,
devices=[bundle_idxs % self.config.gpus_per_node for bundle_idxs in self.engine_bundle_idxs],
empty_init=self.config.skip_load_weights,
+ logprobs_mode="raw_logprobs",
)
)
if backend == "pytorch" and self.accelerator == "NPU":
@@ -276,8 +291,6 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
if "backend" in lmdeploy_config_kwargs:
lmdeploy_config_kwargs.pop("backend")
- lmdeploy_config_kwargs["log_level"] = "CRITICAL" # disable logging
-
return Namespace(
model_path=self.config.model_path,
model_name=self.model_name,
diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py
index 8560c3a25..28cd48854 100644
--- a/xtuner/v1/ray/rollout/worker.py
+++ b/xtuner/v1/ray/rollout/worker.py
@@ -9,84 +9,15 @@
import httpx
import ray
import requests # type: ignore[import-untyped]
-from cyclopts import Parameter
-from pydantic import BaseModel, Field
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-from typing_extensions import Annotated
+from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem
from xtuner.v1.ray import find_master_addr_and_port
from xtuner.v1.ray.accelerator import AutoAcceleratorWorkers, SingleAcceleratorWorker
from xtuner.v1.ray.config import RolloutConfig
from xtuner.v1.utils import get_logger
-class SampleParams(BaseModel):
- """Sampling parameters configuration for text generation in XTuner.
-
- Args:
- n (int): Number of samples to generate for each input. Defaults to 1.
- top_k (int): Number of highest probability vocabulary tokens to keep for
- top-k filtering. Set to 0 to disable. Defaults to 0.
- top_p (float): Cumulative probability threshold for nucleus (top-p) sampling.
- Defaults to 1.0.
- temperature (float): Sampling temperature to control randomness. Lower values
- make output more deterministic. Defaults to 1.0.
- repetition_penalty (float): Penalty applied to tokens that have already
- appeared in the sequence. Defaults to 1.0 (no penalty).
- presence_penalty (float): Penalty applied based on token presence in the
- generated text. Defaults to 0.0.
- frequency_penalty (float): Penalty applied based on token frequency in the
- generated text. Defaults to 0.0.
- min_tokens (int): Minimum number of tokens to generate before considering
- stop conditions. Defaults to 0.
- max_tokens (int): Maximum number of tokens to generate. Defaults to 2048.
- stops (List[str]): List of string sequences that will stop generation when
- encountered. Defaults to empty list.
- stop_token_ids (List[int]): List of token IDs that will stop generation when
- encountered. Defaults to empty list.
- logprobs (int): Number of log probabilities to return for each token.
- Set to 0 to disable. Defaults to 0.
- skip_special_tokens (bool): Whether to skip special tokens during decoding.
- Defaults to True.
- do_sample (bool): Whether to use sampling (True) or greedy decoding (False).
- Defaults to True.
- """
-
- n: Annotated[int, Parameter(help="Number of samples to generate.")] = 1
- top_k: Annotated[
- int, Parameter(help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
- ] = 0
- top_p: Annotated[float, Parameter(help="The cumulative probability for nucleus sampling.")] = 1.0
- temperature: Annotated[float, Parameter(help="The value used to module the next token probabilities.")] = 1.0
- repetition_penalty: Annotated[float, Parameter(help="The parameter for repetition penalty.")] = 1.0
- presence_penalty: Annotated[float, Parameter(help="The parameter for presence penalty.")] = 0.0
- frequency_penalty: Annotated[float, Parameter(help="The parameter for frequency penalty.")] = 0.0
- min_tokens: Annotated[int, Parameter(help="Minimum number of tokens to generate.")] = 0
- max_tokens: Annotated[int, Parameter(help="Maximum number of tokens to generate.")] = 2048
- stops: Annotated[List[str], Parameter(help="List of stop sequences.")] = []
- stop_token_ids: Annotated[List[int], Parameter(help="List of stop token IDs.")] = []
- logprobs: Annotated[int, Parameter(help="Number of log probabilities to return.")] = 0
- skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
- do_sample: Annotated[bool, Parameter(help="Whether to sample or not.")] = True
-
-
-class RolloutRequest(BaseModel):
- messages: Union[str, List[Dict[str, Any]]]
- tools: List = Field(default_factory=list)
- tool_choice: str = "auto"
- sample_params: SampleParams = Field(default_factory=SampleParams)
- extra_params: Dict[str, Any] = Field(default_factory=dict)
-
-
-class RolloutResponse(BaseModel):
- response: str = ""
- logprobs: float = 0.0
- finish_reason: str = ""
- reasoning_content: str = ""
- usage: dict = Field(default_factory=dict)
- tool_calls: List[str] = Field(default_factory=list)
-
-
class RolloutWorker(SingleAcceleratorWorker):
"""Base class for a rollout worker that runs an inference server.
@@ -331,9 +262,13 @@ async def rollout_task(
sample_params: dict,
extra_params: dict,
format: str,
- ) -> RolloutResponse:
+ ) -> RLRolloutResponseItem:
uid = str(uuid.uuid4())
response = None
+ failed_rollout_response = RLRolloutResponseItem(
+ response="",
+ finish_reason="failed",
+ )
try:
if format == "openai":
openai_prompts, openai_tools = prompts, tools
@@ -348,17 +283,14 @@ async def rollout_task(
extra_params=extra_params,
)
self.logger.debug(f" +++ send request {uid} to worker: {self.rank}")
-
- failed_rollout_response = RolloutResponse(
- response="",
- finish_reason="failed",
- )
if response.status_code != 200:
error_body = await response.atext()
self.logger.error(f"Request {uid} failed with status {response.status_code}: {error_body}")
return failed_rollout_response
last_trajectory = ""
+ last_token_ids = []
+ last_logprobs = []
finish_reason = ""
async for chunk in response.aiter_lines():
@@ -377,12 +309,20 @@ async def rollout_task(
chunk_data = json.loads(chunk_data_str)
delta_content = chunk_data["choices"][0]["delta"].get("content")
last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory
+ last_token_id = chunk_data["choices"][0]["delta"].get("gen_tokens")
+ if last_token_id is not None:
+ last_token_ids.extend(last_token_id)
finish_reason = chunk_data["choices"][0].get("finish_reason")
-
+ logprobs_content = chunk_data["choices"][0]["logprobs"]
+ if logprobs_content is not None:
+ for content_item in logprobs_content["content"]:
+ last_logprobs.append(content_item["logprob"])
# todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids.
if finish_reason == "stop":
assert len(sample_params["stops"]) == 1
last_trajectory += sample_params["stops"][0]
+ if len(last_token_ids) > 0:
+ last_token_ids.append(sample_params["stop_token_ids"][0])
except json.JSONDecodeError as e:
self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}")
@@ -394,10 +334,13 @@ async def rollout_task(
assert finish_reason in ["stop", "length", "tool_call", "paused", "failed"], (
f"Unexpected finish_reason: {finish_reason}"
)
-
- rollout_response = RolloutResponse(
+ self.logger.debug(f" --- request {uid} finished with reason: {finish_reason}")
+ rollout_response = RLRolloutResponseItem(
response=last_trajectory,
+ response_ids=last_token_ids if len(last_token_ids) > 0 else None,
+ num_return_tokens=len(last_token_ids),
finish_reason=finish_reason,
+ logprobs=last_logprobs,
)
return rollout_response
@@ -420,7 +363,7 @@ async def rollout(
sample_params: dict = dict(),
extra_params: dict = dict(),
format: str = "openai",
- ) -> RolloutResponse:
+ ) -> RLRolloutResponseItem:
"""Public method to initiate a rollout.
Args:
diff --git a/xtuner/v1/train/cli/grpo.py b/xtuner/v1/train/cli/grpo.py
index 41c6ec2b8..ba95e349b 100644
--- a/xtuner/v1/train/cli/grpo.py
+++ b/xtuner/v1/train/cli/grpo.py
@@ -3,7 +3,6 @@
import ray
-from transformers import AutoTokenizer
from xtuner.v1.config import (
AdamWConfig,
FSDPConfig,
@@ -79,8 +78,8 @@ def main(args):
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig()
- judger_cfg = JudgerConfig(reward_judger_configs={"openai/gsm8k": gsm8k_judger_config})
+ gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
+ judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config])
train_dataset_cfg = [
{
"dataset": DatasetConfig(name="gsm8k", anno_path=args.data_path, sample_ratio=1.0),
@@ -102,7 +101,6 @@ def main(args):
collator="fake_collator",
pack_level="none",
)
- tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
if eval_dataset_cfg:
evaluator_cfg = EvaluatorConfig(
dataset_cfg=eval_dataset_cfg,
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 08d7059b9..3832d73b1 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -321,16 +321,16 @@ def _prepare_train_data(self, data_groups, pack_max_length):
data_batches = []
for group in data_groups:
prompt = self.tokenizer.apply_chat_template(
- group[0]["messages"], add_generation_prompt=True, tokenize=False
+ group[0].data.messages, add_generation_prompt=True, tokenize=False
)
prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].flatten().tolist()
- rewards = [data["reward"] for data in group]
+ rewards = [data.env.judger.reward["weighted_reward"] for data in group]
rewards = torch.tensor(rewards, dtype=torch.float32)
advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)
prompt_repeat_k = len(group)
for i in range(prompt_repeat_k):
- item = group[i]["response_str"]
+ item = group[i].env.rollout.response
response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist()
input_ids = prompt_ids + response_ids
shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100]
@@ -355,12 +355,12 @@ def _save_trajectories(self, data_groups, save_path):
response_list = []
reward_list = []
for data in group:
- response_list.append(data["response_str"])
- reward_list.append(data["reward"])
+ response_list.append(data.env.rollout.response)
+ reward_list.append(data.env.judger.reward["weighted_reward"])
item = {
- "messages": group[0]["messages"],
+ "messages": group[0].data.messages,
"response": response_list,
- "label": group[0]["reward_model"]["ground_truth"],
+ "label": group[0].data.reward_model["ground_truth"],
"reward": reward_list,
}
json.dump(item, f)
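
The advantage computation above is group-relative: rewards within one prompt group are standardized against that group's mean and (unbiased) standard deviation. A small worked example of the same formula, with hypothetical rewards for prompt_repeat_k = 4:

import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])  # weighted_reward per rollout in the group
advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)
# mean = 0.5, unbiased std ~= 0.5774  ->  advantages ~= [0.866, -0.866, 0.866, -0.866]
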
diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py
index 428a12fbc..d9c15cfa5 100644
--- a/xtuner/v1/utils/rl_test_utils.py
+++ b/xtuner/v1/utils/rl_test_utils.py
@@ -7,6 +7,7 @@
from fastapi import FastAPI
from pydantic import BaseModel, Field
+from xtuner.v1.data_proto.rl_data import RLJudgerResponseItem
from xtuner.v1.ray.judger.gsm8k import compute_reward
from xtuner.v1.ray.judger.native import NativeJudger
@@ -70,14 +71,23 @@ def stop(self):
def custom_postprocessor_for_gsm8k(result):
- return result["reward"]
+ if not isinstance(result, list):
+ result = [result]
+ judger_response_item = [
+ RLJudgerResponseItem(uid=result[i]["uid"], reward={"reward": result[i]["reward"]}) for i in range(len(result))
+ ]
+ return judger_response_item
class GSM8KRemoteJudgerConfig(BaseModel):
+ judger_name: str
remote_url: str
extra_info: dict = {"score": 1, "format_score": 0}
def build(self):
return NativeJudger(
- remote_url=self.remote_url, postprocess_func=custom_postprocessor_for_gsm8k, extra_info=self.extra_info
+ judger_name=self.judger_name,
+ remote_url=self.remote_url,
+ postprocess_func=custom_postprocessor_for_gsm8k,
+ extra_info=self.extra_info,
)
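
The remote path follows the same contract: `_remote_executor` tags the HTTP result with the request uid, and the custom postprocessor wraps it into `RLJudgerResponseItem`. A minimal configuration sketch against a hypothetical local reward service (the URL and endpoint are placeholders):

from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.utils.rl_test_utils import GSM8KRemoteJudgerConfig

remote_judger_cfg = GSM8KRemoteJudgerConfig(
    judger_name="openai/gsm8k",
    remote_url="http://127.0.0.1:8018/judge",  # placeholder endpoint for the test reward server
)
judger_cfg = JudgerConfig(reward_judger_configs=[remote_judger_cfg])
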