209 changes: 209 additions & 0 deletions ci/scripts/test_dapo_trainer.py
@@ -0,0 +1,209 @@
import os
import re
from pathlib import Path
import ray
import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch.distributed as dist
from transformers import AutoTokenizer

from xtuner.v1.config import (
AdamWConfig,
FSDPConfig,
LRConfig,
)
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig
from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig
from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.rollout import SampleParams
from xtuner.v1.ray.evaluator import EvaluatorConfig
from xtuner.v1.datasets import RLTextTokenizeFnConfig
from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.train.rl_trainer import RLTrainer

MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"]
os.environ['XTUNER_USE_FA3'] = "1"

def parse_args():
parser = argparse.ArgumentParser(description="DAPO Trainer Test Script")
parser.add_argument("--total-epochs", type=int)
parser.add_argument("--work-dir", type=str, default="work_dir")
parser.add_argument("--model-path", type=str, default=MODEL_PATH)
parser.add_argument("--data-path", type=str, default=TRAIN_DATA_PATH)
parser.add_argument("--eval-data-path", type=str, default=TEST_DATA_PATH)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--gpus-per-node", type=int, default=8)
parser.add_argument("--rollout-global-batch-size", type=int, default=128)
parser.add_argument("--train-optimizer-steps", type=int, default=1)
parser.add_argument("--max-concurrent", type=int, default=8)
parser.add_argument("--prompt-repeat-k", type=int, default=8)
parser.add_argument("--pack-max-length", type=int, default=8192)
parser.add_argument("--max-prompt-length", type=int, default=512)
parser.add_argument("--max-response-length", type=int, default=1024)
parser.add_argument("--optimizer-disable-foreach", action="store_true") # save memory usage during opt.step()
parser.add_argument("--policy-loss-type", type=str, default="vanilla")
parser.add_argument("--enable-evaluate", action="store_true")
parser.add_argument("--evaluate-step", type=int, default=1)
parser.add_argument("--evaluate-ratio", type=float, default=1)
parser.add_argument("--ray-cluster-url", type=str, default="")
return parser.parse_args()


def main(args):
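# Start a local Ray instance unless an existing cluster address is provided.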
if args.ray_cluster_url == "":
ray.init(num_cpus=128, ignore_reinit_error=True)
else:
ray.init(address=args.ray_cluster_url, ignore_reinit_error=True)
load_from = args.model_path
resources = AcceleratorResourcesConfig(
accelerator="GPU",
num_accelerators_per_worker=1,
num_cpus_per_worker=12,
num_workers=args.num_workers,
cpu_memory_per_worker=16 * 1024**3, # 16 GB
)
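# Rollout (inference) engine config: the policy is served with tensor parallelism of 2 on bfloat16 weights.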
rollout_config = RolloutConfig(
env="test_env",
model_path=args.model_path,
model_name=os.path.basename(args.model_path).lower(),
tokenizer_path=args.model_path,
rollout_cross_node_comm=False,
tensor_parallel_size=2,
expert_parallel_size=1,
gpus_per_node=args.gpus_per_node, # gpu: 8, npu: 16
dtype="bfloat16",
skip_load_weights=False,
)
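# Dataflow config: each prompt is sampled prompt_repeat_k times so group-relative advantages can be computed (GRPO/DAPO style).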
dataflow_config = DataFlowConfig(
env="test",
max_concurrent=args.max_concurrent,
prompt_repeat_k=args.prompt_repeat_k,
global_batch_size=args.rollout_global_batch_size,
sample_params=SampleParams(
max_tokens=args.max_response_length,
# Greedy decoding alternative (disabled): top_k=20, temperature=1e-6
top_k=0,
top_p=1.0,
temperature=1.0,
min_tokens=0,
# optional: stop_token_ids=[], logprobs=0, skip_special_tokens=True
do_sample=True,
),
)
# from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
# gsm8k_judger_config = GSM8KJudgerConfig()
# judger_cfg = JudgerConfig(
# reward_judger_configs={"openai/gsm8k": gsm8k_judger_config}
# )
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
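# The positional args presumably correspond to: enable overlong penalty, max response length, overlong buffer length, penalty factor, tokenizer — check DapoMathJudgerConfig for the exact signature.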
dapomath_judger_config = DapoMathJudgerConfig(True, args.max_response_length, 4096, 1.0, tokenizer)
judger_cfg = JudgerConfig(
reward_judger_configs={"math_dapo": dapomath_judger_config}
)
train_dataset_cfg = [
{
"dataset": DatasetConfig(name="dapo_math",
anno_path=args.data_path,
sample_ratio=1.0),
"tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length),
},
]
eval_dataset_cfg = [
{
"dataset": DatasetConfig(name="gsm8k",
anno_path=args.eval_data_path,
sample_ratio=1.0),
"tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length),
},
]
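# 'fake_collator' with pack_level='none': packing is presumably handled by the training worker (see pack_max_length below), not the dataloader.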
dataloader_cfg = DataloaderConfig(
pack_max_length=args.pack_max_length,
collator='fake_collator',
pack_level='none',
)
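# Periodic evaluation on the gsm8k test split every `evaluate_step` training steps.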
evaluator_cfg = EvaluatorConfig(
dataset_cfg=eval_dataset_cfg,
tokenizer=tokenizer,
max_concurrent=args.max_concurrent,
eval_sample_ratio=args.evaluate_ratio,
evaluate_step=args.evaluate_step,
compute_metric_func=None
)
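# Replay buffer: holds the tokenized prompts and rollout trajectories consumed by the training workers.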
replay_buffer_cfg = ReplayBufferConfig(
dataset_cfg=train_dataset_cfg,
dataloader_cfg=dataloader_cfg,
tokenizer=tokenizer,
postprocessor=None
)
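# FSDP training worker: a Qwen2 7B policy trained with the GRPO loss; the asymmetric clip ranges (0.2 low / 0.28 high) follow DAPO's clip-higher scheme.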
train_worker_cfg: WorkerConfig = WorkerConfig(
# model_cfg=Qwen3Dense8BConfig(),
model_cfg=Qwen2Dense7BConfig(),
optim_cfg=AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False if args.optimizer_disable_foreach else None),
loss_cfg=GRPOLossConfig(
policy_loss_cfg=dict(
cliprange_high=0.28,
cliprange_low=0.2,
loss_type=args.policy_loss_type,
),
ignore_idx=-100,
use_kl_loss=False,
kl_loss_coef=0.0,
kl_loss_type="low_var_kl",
mode="chunk",
chunk_size=512),
lr_cfg=LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6),
fsdp_cfg=FSDPConfig(
torch_compile=False,
cpu_offload=False,
ep_size=1,
),
load_from=args.model_path,
sp_size=1,
optimizer_steps=args.train_optimizer_steps,
pack_max_length=args.pack_max_length,
)
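# Wire everything into the RL trainer, which orchestrates rollout, judging, replay buffering, evaluation, and policy updates.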
trainer = RLTrainer(
load_from=load_from,
resources=resources,
rollout_config=rollout_config,
dataflow_config=dataflow_config,
judger_config=judger_cfg,
replay_buffer_config=replay_buffer_cfg,
evaluator_config=evaluator_cfg,
train_worker_cfg=train_worker_cfg,
tokenizer_path=args.model_path,
work_dir=args.work_dir,
total_epochs=args.total_epochs,
enable_evaluate=args.enable_evaluate
)
trainer.fit()


if __name__ == "__main__":
args = parse_args()
main(args)
30 changes: 30 additions & 0 deletions ci/scripts/test_dapo_trainer_bash_7B_nogroup.sh
@@ -0,0 +1,30 @@
set -ex

export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B"
export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_razor/caoweihan/dapo-math-17k.jsonl"
export ROLLOUT_TEST_DATA_PATH="/cpfs01/shared/llm_razor/huanghaian/code/refactor_xtuner/gsm8k/test.jsonl"
export XTUNER_USE_LMDEPLOY=1
export XTUNER_USE_FA3=1
export PYTHONPATH='/cpfs01/shared/llm_razor/caoweihan/projects/lmdeploy':'/cpfs01/shared/llm_ddd/caoweihan/projects/Liger-Kernel/src/':'.':$PYTHONPATH
export UVICORN_LOG_LEVEL="CRITICAL"
export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'

OUTPUT_DIR='work_dirs/dapo_math_7B_newlmdeploy_nogroup'
if [ ! -d "$OUTPUT_DIR" ]; then
mkdir -p "$OUTPUT_DIR"
fi
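
# Larger-scale DAPO run: 16 rollouts per prompt, responses up to 8192 tokens, packed into 32768-token training sequences.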

python ci/scripts/test_dapo_trainer.py \
--total-epochs 1 \
--work-dir "$OUTPUT_DIR" \
--num-workers 8 \
--gpus-per-node 8 \
--rollout-global-batch-size 512 \
--train-optimizer-steps 16 \
--max-concurrent 64 \
--prompt-repeat-k 16 \
--pack-max-length 32768 \
--max-prompt-length 2048 \
--max-response-length 8192 \
--optimizer-disable-foreach \
2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
3 changes: 2 additions & 1 deletion ci/scripts/test_grpo_trainer.py
@@ -16,6 +16,7 @@
)
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig
from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig
from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
@@ -134,7 +135,7 @@ def main(args):
postprocessor=None
)
train_worker_cfg: WorkerConfig = WorkerConfig(
-model_cfg=Qwen3Dense8BConfig(),
+model_cfg=Qwen2Dense7BConfig(),
optim_cfg=AdamWConfig(lr=1e-6, foreach=False if args.optimizer_disable_foreach else None),
loss_cfg=GRPOLossConfig(
policy_loss_cfg=dict(