diff --git a/ci/scripts/test_dapo_trainer.py b/ci/scripts/test_dapo_trainer.py new file mode 100755 index 000000000..fabc83667 --- /dev/null +++ b/ci/scripts/test_dapo_trainer.py @@ -0,0 +1,209 @@ +import os +import re +from pathlib import Path +import ray +import argparse + +import matplotlib.pyplot as plt +import numpy as np +import torch.distributed as dist +from transformers import AutoTokenizer + +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig +from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig +from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig +from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig +from xtuner.v1.ray.rollout import SampleParams +from xtuner.v1.ray.evaluator import EvaluatorConfig +from xtuner.v1.datasets import RLTextTokenizeFnConfig +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.ray.judger.controller import JudgerConfig +from xtuner.v1.rl.base import WorkerConfig +from xtuner.v1.rl.grpo import GRPOLossConfig +# from xtuner.v1.rl.grpo import GRPOLossConfig, WorkerConfig +# from xtuner.v1.rl.grpo.config import WorkerConfig, LossConfig +# from xtuner.v1.rl.grpo.trainer import Trainer +from xtuner.v1.train.rl_trainer import RLTrainer + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +os.environ['XTUNER_USE_FA3'] = "1" + +def parse_args(): + parser = argparse.ArgumentParser(description="VLLM Rollout Test Script") + parser.add_argument("--total-epochs", type=int) + parser.add_argument("--work-dir", type=str, default="work_dir") + parser.add_argument("--model-path", type=str, default=MODEL_PATH) + parser.add_argument("--data-path", type=str, default=TRAIN_DATA_PATH) + parser.add_argument("--eval-data-path", type=str, default=TEST_DATA_PATH) + parser.add_argument("--num-workers", type=int, default=8) + parser.add_argument("--gpus-per-node", type=int, default=8) + parser.add_argument("--rollout-global-batch-size", type=int, default=128) + parser.add_argument("--train-optimizer-steps", type=int, default=1) + parser.add_argument("--max-concurrent", type=int, default=8) + parser.add_argument("--prompt-repeat-k", type=int, default=8) + parser.add_argument("--pack-max-length", type=int, default=8192) + parser.add_argument("--max-prompt-length", type=int, default=512) + parser.add_argument("--max-response-length", type=int, default=1024) + parser.add_argument("--optimizer-disable-foreach", action="store_true") # save memory usage during opt.step() + parser.add_argument("--policy-loss-type", type=str, default="vanilla") + parser.add_argument("--enable-evaluate", action="store_true") + parser.add_argument("--evaluate-step", type=int, default=1) + parser.add_argument("--evaluate-ratio", type=float, default=1) + parser.add_argument("--ray-cluster-url", type=str, default="") + return parser.parse_args() + + +def main(args): + if args.ray_cluster_url == "": + ray.init(num_cpus=128, ignore_reinit_error=True) + else: + ray.init(address=args.ray_cluster_url, ignore_reinit_error=True) + load_from = args.model_path + resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_accelerators_per_worker=1, + num_cpus_per_worker=12, + num_workers=args.num_workers, + cpu_memory_per_worker=16 * 
1024**3, # 16 GB + ) + rollout_config = RolloutConfig( + env="test_env", + model_path=args.model_path, + model_name=os.path.basename(args.model_path).lower(), + tokenizer_path=args.model_path, + rollout_cross_node_comm=False, + tensor_parallel_size=2, + expert_parallel_size=1, + gpus_per_node=args.gpus_per_node, # gpu: 8, npu: 16 + dtype="bfloat16", + skip_load_weights=False, + ) + dataflow_config = DataFlowConfig( + env="test", + max_concurrent=args.max_concurrent, + prompt_repeat_k=args.prompt_repeat_k, + global_batch_size=args.rollout_global_batch_size, + sample_params=SampleParams( + max_tokens=args.max_response_length, + # ###### greedy + # top_k=20, + # # temperature=1e-6, + ########## + top_k=0, + top_p=1.0, + temperature=1.0, + + min_tokens=0, + # stop_token_ids= [], + # logprobs= 0, + # skip_special_tokens= True, + do_sample=True, + ), + ) + # from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig + # gsm8k_judger_config = GSM8KJudgerConfig() + # judger_cfg = JudgerConfig( + # reward_judger_configs={"openai/gsm8k": gsm8k_judger_config} + # ) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig + dapomath_judger_config = DapoMathJudgerConfig(True, args.max_response_length, 4096, 1.0, tokenizer) + judger_cfg = JudgerConfig( + reward_judger_configs={"math_dapo": dapomath_judger_config} + ) + train_dataset_cfg = [ + { + "dataset": DatasetConfig(name="dapo_math", + anno_path=args.data_path, + sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length), + }, + ] + eval_dataset_cfg = [ + { + "dataset": DatasetConfig(name="gsm8k", + anno_path=args.eval_data_path, + sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length), + }, + ] + dataloader_cfg = DataloaderConfig( + pack_max_length=args.pack_max_length, + collator='fake_collator', + pack_level='none', + ) + # tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + evaluator_cfg = EvaluatorConfig( + dataset_cfg=eval_dataset_cfg, + tokenizer=tokenizer, + max_concurrent=args.max_concurrent, + eval_sample_ratio=args.evaluate_ratio, + evaluate_step=args.evaluate_step, + compute_metric_func=None + ) + replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=train_dataset_cfg, + dataloader_cfg=dataloader_cfg, + tokenizer=tokenizer, + postprocessor=None + ) + train_worker_cfg: WorkerConfig = WorkerConfig( + # model_cfg=Qwen3Dense8BConfig(), + model_cfg=Qwen2Dense7BConfig(), + optim_cfg=AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False if args.optimizer_disable_foreach else None), + loss_cfg=GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=args.policy_loss_type, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode="chunk", + chunk_size=512), + lr_cfg=LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6), + fsdp_cfg=FSDPConfig( + torch_compile=False, + cpu_offload=False, + ep_size=1, + ), + load_from=args.model_path, + sp_size=1, + optimizer_steps=args.train_optimizer_steps, + pack_max_length=args.pack_max_length, + ) + trainer = RLTrainer( + load_from=load_from, + resources=resources, + rollout_config=rollout_config, + dataflow_config=dataflow_config, + judger_config=judger_cfg, + replay_buffer_config=replay_buffer_cfg, + evaluator_config=evaluator_cfg, + train_worker_cfg=train_worker_cfg, + 
tokenizer_path=args.model_path, + work_dir=args.work_dir, + total_epochs=args.total_epochs, + enable_evaluate=args.enable_evaluate + ) + trainer.fit() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/ci/scripts/test_dapo_trainer_bash_7B_nogroup.sh b/ci/scripts/test_dapo_trainer_bash_7B_nogroup.sh new file mode 100755 index 000000000..58d7a9bd2 --- /dev/null +++ b/ci/scripts/test_dapo_trainer_bash_7B_nogroup.sh @@ -0,0 +1,30 @@ +set -ex + +export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B" +export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_razor/caoweihan/dapo-math-17k.jsonl" +export ROLLOUT_TEST_DATA_PATH="/cpfs01/shared/llm_razor/huanghaian/code/refactor_xtuner/gsm8k/test.jsonl" +export XTUNER_USE_LMDEPLOY=1 +export XTUNER_USE_FA3=1 +export PYTHONPATH='/cpfs01/shared/llm_razor/caoweihan/projects/lmdeploy':'/cpfs01/shared/llm_ddd/caoweihan/projects/Liger-Kernel/src/':'.':$PYTHONPATH +export UVICORN_LOG_LEVEL="CRITICAl" +export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' + +OUTPUT_DIR='work_dirs/dapo_math_7B_newlmdeploy_nogroup' +if [ ! -d "$OUTPUT_DIR" ]; then + mkdir -p "$OUTPUT_DIR" +fi + +python ci/scripts/test_dapo_trainer.py \ + --total-epochs 1 \ + --work-dir "$OUTPUT_DIR" \ + --num-workers 8 \ + --gpus-per-node 8 \ + --rollout-global-batch-size 512 \ + --train-optimizer-steps 16 \ + --max-concurrent 64 \ + --prompt-repeat-k 16 \ + --pack-max-length 32768 \ + --max-prompt-length 2048 \ + --max-response-length 8192 \ + --optimizer-disable-foreach \ + 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" diff --git a/ci/scripts/test_grpo_trainer.py b/ci/scripts/test_grpo_trainer.py index 2414baa9f..e1ee4f4a7 100644 --- a/ci/scripts/test_grpo_trainer.py +++ b/ci/scripts/test_grpo_trainer.py @@ -16,6 +16,7 @@ ) from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig +from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig from xtuner.v1.ray.config.worker import RolloutConfig from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig @@ -134,7 +135,7 @@ def main(args): postprocessor=None ) train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=Qwen3Dense8BConfig(), + model_cfg=Qwen2Dense7BConfig(), optim_cfg=AdamWConfig(lr=1e-6, foreach=False if args.optimizer_disable_foreach else None), loss_cfg=GRPOLossConfig( policy_loss_cfg=dict( diff --git a/ci/scripts/test_lmdeploy_vllm.py b/ci/scripts/test_lmdeploy_vllm.py new file mode 100644 index 000000000..14313cd9b --- /dev/null +++ b/ci/scripts/test_lmdeploy_vllm.py @@ -0,0 +1,157 @@ +import argparse +import json +import os +import time +from pathlib import Path + +import ray +from transformers import AutoTokenizer + +from xtuner.v1.datasets.config import ( + DataloaderConfig, + DatasetConfig, +) +from xtuner.v1.ray.rollout import SampleParams +from xtuner.v1.datasets import RLTextTokenizeFnConfig +from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig +from xtuner.v1.ray.environment import SingleTurnEnvironment +from xtuner.v1.ray.judger import JudgerConfig + +import numpy as np + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] + +def parse_args(): + parser 
= argparse.ArgumentParser(description="Env Generate Test Script") + parser.add_argument("--work-dir", type=str, default="work_dir") + parser.add_argument("--global-batch-size", type=int, default=8) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument("--top-p", type=float, default=1) + parser.add_argument("--temperature", type=float, default=1) + parser.add_argument("--prompt-repeat-k", type=int, default=1) + parser.add_argument("--max-prompt-length", type=int, default=1) + parser.add_argument("--max-response-length", type=int, default=1) + parser.add_argument("--repeat-times", type=int, default=1) + parser.add_argument("--enable-partial-rollout", type=int, default=0) + parser.add_argument("--vllm", action="store_true") + return parser.parse_args() + +def main(): + args = parse_args() + os.makedirs(args.work_dir, exist_ok=True) + ray.init(num_cpus=80) + resources_cfg = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + if args.vllm: + rollout_cfg = RolloutConfig( + env="test_env", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=1, + backend="vllm", + launch_server_method="multiprocessing", + ) + else: + rollout_cfg = RolloutConfig( + env="test_env", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=1, + ) + from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig + gsm8k_judger_config = GSM8KJudgerConfig() + judger_cfg = JudgerConfig( + reward_judger_configs={"openai/gsm8k": gsm8k_judger_config} + ) + + dataflow_cfg = DataFlowConfig( + env="test", + prompt_repeat_k=args.prompt_repeat_k, + global_batch_size=args.global_batch_size, + enable_partial_rollout=args.enable_partial_rollout, + sample_params=SampleParams( + max_tokens=args.max_response_length, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + min_tokens=0, + ), + ) + dataset_cfg = [ + { + "dataset": DatasetConfig(name="gsm8k", + anno_path=TRAIN_DATA_PATH, + sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length), + }, + ] + dataloader_cfg = DataloaderConfig( + pack_max_length=16384, + collator='fake_collator', + pack_level='none', + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=dataset_cfg, + dataloader_cfg=dataloader_cfg, + tokenizer=tokenizer, + postprocessor=None + ) + pg = AutoAcceleratorWorkers.build_placement_group(resources_cfg) + test_env = SingleTurnEnvironment.remote( + "test_env", + pg, + rollout_cfg, + judger_cfg + ) + test_flow = DataFlow.remote("test_env", + dataflow_cfg, + replay_buffer_cfg, + test_env) + + for i in range(6): + rollout_data = ray.get(test_flow.run.remote()) + dataflow_state = ray.get(test_flow.state.remote()) + result_path = os.path.join(args.work_dir, f"vllm{args.vllm}_rollout_results_round{i}.jsonl") + with open(result_path, "w") as f: + for group in rollout_data: + group_prompt_len = [] + group_response_len = [] + group_reward_list = [] + for data in group: + promt_str = group[0]["messages"][0]['content'] + prompt_ids = tokenizer(promt_str, return_tensors="pt")["input_ids"].flatten().tolist() + group_prompt_len.append(len(prompt_ids)) + + response_str = data["response_str"] + response_ids = tokenizer(response_str, return_tensors="pt")["input_ids"].flatten().tolist() + 
group_response_len.append(len(response_ids)) + + group_reward_list.append(data["reward"]) + item = { + "messages": group[0]["messages"], + "response": response_str, + "label": group[0]["reward_model"]["ground_truth"], + "reward": data["reward"], + } + json.dump(item, f, ensure_ascii=False) + f.write('\n') + + print(f"test_time{i}===============================================") + print(f"prompt_len: mean {np.mean(group_prompt_len)} max {np.max(group_prompt_len)} min {np.min(group_prompt_len)} std {np.std(group_prompt_len)}") + print(f"response_len: mean {np.mean(group_response_len)} max {np.max(group_response_len)} min {np.min(group_response_len)} std {np.std(group_response_len)}") + print(f"Average reward: {np.mean(group_reward_list)}") + time.sleep(2) + ray.get(test_env.shutdown.remote()) + +if __name__ == "__main__": + main() diff --git a/ci/scripts/test_lmdeploy_vllm.sh b/ci/scripts/test_lmdeploy_vllm.sh new file mode 100644 index 000000000..af9269001 --- /dev/null +++ b/ci/scripts/test_lmdeploy_vllm.sh @@ -0,0 +1,53 @@ +set -ex + +# export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_ddd/opencompass/models/modelscope_hub/QwQ/Qwen3-30B-A3B-250425" +export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_razor/huanghaian/new_model/Qwen3-8B" +export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B" + +# export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_razor/huanghaian/code/refactor_xtuner/gsm8k/train.jsonl" +export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_ddd/lishuaibin/verl_dirs/data/dapo-math-17k_1.jsonl" +export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_ddd/lishuaibin/verl_dirs/data/gsm8k_1.jsonl" +export ROLLOUT_TEST_DATA_PATH="/cpfs01/shared/llm_razor/huanghaian/code/refactor_xtuner/gsm8k/test.jsonl" + +# export PYTHONPATH='/cpfs01/shared/llm_razor/caoweihan/projects/lmdeploy':'/cpfs01/shared/llm_ddd/caoweihan/projects/Liger-Kernel/src/':'.':$PYTHONPATH +export PYTHONPATH='/cpfs01/shared/llm_razor/duanyanhui/workspace/lmdeploy/lmdeploy':'/cpfs01/shared/llm_ddd/caoweihan/projects/Liger-Kernel/src/':'.':$PYTHONPATH + +export UVICORN_LOG_LEVEL="CRITICAl" +export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' + +export PYTHONPATH='/cpfs01/user/lishuaibin/projects/202509/xtuner_github/_push/xtuner_github/':$PYTHONPATH + +# OUTPUT_DIR='work_dirs/dense_8b_gsm8k_grpo_fix_shuaibin' +OUTPUT_DIR='work_dirs/debug_7B_lmdeploy-yanhui_vllm' +if [ ! 
-d "$OUTPUT_DIR" ]; then + mkdir -p "$OUTPUT_DIR" +fi + +export XTUNER_USE_LMDEPLOY=1 +python ci/scripts/test_lmdeploy_vllm.py \ + --work-dir "$OUTPUT_DIR" \ + --global-batch-size 1 \ + --top-k 0 \ + --top-p 1.0 \ + --temperature 1.0 \ + --prompt-repeat-k 64 \ + --max-prompt-length 2048 \ + --max-response-length 8192 \ + 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" + +# sleep 2 + +# export XTUNER_USE_VLLM=1 +# python ci/scripts/test_lmdeploy_vllm.py \ +# --work-dir "$OUTPUT_DIR" \ +# --global-batch-size 1 \ +# --top-k -1 \ +# --top-p 1.0 \ +# --temperature 1.0 \ +# --prompt-repeat-k 64 \ +# --max-prompt-length 2048 \ +# --max-response-length 8192 \ +# --vllm \ +# 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" + +# sleep 2 \ No newline at end of file diff --git a/ci/scripts/test_sft_dense_trainer.py b/ci/scripts/test_sft_dense_trainer.py index 4babd48fd..677c135f0 100644 --- a/ci/scripts/test_sft_dense_trainer.py +++ b/ci/scripts/test_sft_dense_trainer.py @@ -19,6 +19,7 @@ from xtuner.v1.utils.compile import maybe_compile import argparse from xtuner.v1.loss import CELossConfig +from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig QWEN3_PATH = os.environ["QWEN3_PATH"] ALPACA_PATH = os.environ["ALPACA_PATH"] @@ -217,7 +218,7 @@ def main(): os.environ["DG_CACHE_DIR"] = f"/tmp/.adaptive_gemm-{os.getenv('RANK', '0')}" dense_cfgs = [ - (Qwen3Dense8BConfig(), "tp1"), + (Qwen2Dense7BConfig(), "dense"), ] for dense_cfg, name in dense_cfgs: optim_cfg = AdamWConfig(lr=6e-05) diff --git a/ci/scripts/test_sft_trainer.py b/ci/scripts/test_sft_trainer.py index 23ed6646e..181752d0a 100644 --- a/ci/scripts/test_sft_trainer.py +++ b/ci/scripts/test_sft_trainer.py @@ -20,6 +20,7 @@ from xtuner.v1.utils.device import get_device from xtuner.v1.loss import CELossConfig import argparse +from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig @@ -218,8 +219,9 @@ def main(): os.environ["DG_CACHE_DIR"] = f"/tmp/.adaptive_gemm-{os.getenv('RANK', '0')}" moe_cfgs = [ - (Qwen3MoE30BA3Config(balancing_loss_cfg=BalancingLossConfig()), "ep1"), - (Qwen3MoE30BA3Config(ep_size=8, dispatcher="all2all"), "ep8"), + # (Qwen3MoE30BA3Config(balancing_loss_cfg=BalancingLossConfig()), "ep1"), + # (Qwen3MoE30BA3Config(ep_size=8, dispatcher="all2all"), "ep8"), + (Qwen2Dense7BConfig(), "dense") ] for moe_cfg, name in moe_cfgs: optim_cfg = AdamWConfig(lr=6e-05) @@ -227,7 +229,7 @@ def main(): fsdp_cfg = FSDPConfig( torch_compile=False, #get_device() == "cuda", cpu_offload=False, - ep_size=moe_cfg.ep_size, + ep_size=1, # hsdp_sharding_size=4, ) dataset_config = [ diff --git a/tests/model/test_qwen3_dense.py b/tests/model/test_qwen3_dense.py index 50b654da4..74496e7af 100644 --- a/tests/model/test_qwen3_dense.py +++ b/tests/model/test_qwen3_dense.py @@ -16,6 +16,7 @@ from xtuner.v1.config import FSDPConfig from xtuner.v1.utils.compile import maybe_compile from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem +from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig # Qwen3 8B QWEN3_PATH = os.environ["QWEN3_PATH"] @@ -26,7 +27,7 @@ class TestQwen3Dense(DistributedTestBase): "device,tp_size,compile,tol,loss_class", [ ("cuda", 1, False, 1e-2, "cross_entropy"), - ("cuda", 1, False, 1e-2, "chunk_cross_entropy"), + # ("cuda", 1, False, 1e-2, "chunk_cross_entropy"), ], ) def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class): @@ -52,7 +53,8 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class): torch.cuda.empty_cache() with torch.device("meta"): - cfg = Qwen3Dense8BConfig() 
+ # cfg = Qwen3Dense8BConfig() + cfg = Qwen2Dense7BConfig() qwen_model = cfg.build().to(torch.bfloat16) shift_input_ids = input_ids[:, :-1] @@ -108,7 +110,7 @@ def test_fsdp_accuracy(self, device, tp_size): torch.cuda.empty_cache() with torch.device("meta"): - cfg = Qwen3Dense8BConfig() + cfg = Qwen2Dense7BConfig() qwen_model = cfg.build().to(torch.bfloat16) fsdp_config = FSDPConfig( @@ -139,6 +141,7 @@ def test_fsdp_accuracy(self, device, tp_size): loss_ctx=loss_ctx, ) loss = output["loss"] + dist.breakpoint() self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2)) @parametrize.parametrize( diff --git a/xtuner/v1/model/dense/qwen2.py b/xtuner/v1/model/dense/qwen2.py new file mode 100755 index 000000000..a4a9097c2 --- /dev/null +++ b/xtuner/v1/model/dense/qwen2.py @@ -0,0 +1,54 @@ +import re + +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.module.attention import MHAConfig + +from .dense import Dense + + +class Qwen2Dense(Dense): + def to_hf_key_list(self, key: str) -> list[str]: + if self.config.tie_word_embeddings and "lm_head" in key: + key = key.replace("lm_head", "embed_tokens") + + if "layers" in key or "embed_tokens" in key: + key = "model." + key + + if "layers" in key: + key = re.sub(r"layers\.(\d+)\.(experts|gate)", r"layers.\1.mlp.\2", key) + + if key.startswith("norm."): + return [key.replace("norm.", "model.norm.")] + else: + return [key] + + +class Qwen2DenseConfig(TransformerConfig): + use_sliding_window: bool = False + + def build(self) -> Qwen2Dense: + return Qwen2Dense(self) + + +# TODO: Unify the config name style +class Qwen2Dense7BConfig(Qwen2DenseConfig): + vocab_size: int = 152064 + max_position_embeddings: int = 32768 + pad_token_id: int = 151645 # eos_id + eos_token_id: int = 151645 + bos_token_id: int = 151643 + num_hidden_layers: int = 28 + hidden_size: int = 3584 + intermediate_size: int = 18944 + rms_norm_eps: float = 1e-06 + rope_theta: float = 10000 + hidden_act: str = "silu" + attention: MHAConfig = MHAConfig( + num_attention_heads=28, + num_key_value_heads=4, + head_dim=128, + qk_norm=False, + qkv_bias=True, + ) + # sliding_window= 4096 + tie_word_embeddings: bool = False diff --git a/xtuner/v1/ray/judger/dapo_math.py b/xtuner/v1/ray/judger/dapo_math.py new file mode 100755 index 000000000..545038fce --- /dev/null +++ b/xtuner/v1/ray/judger/dapo_math.py @@ -0,0 +1,407 @@ +import re +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from .native import NativeJudger + + +# _SOLUTION_CLIP_CHARS = 300 + + +# def extract_solution(solution_str, method="strict"): +# assert method in ["strict", "flexible"] + +# # Optimization: Regular expression matching on very long strings can be slow. +# # For math problems, the final answer is usually at the end. +# # We only match on the last 300 characters, which is a safe approximation for 300 tokens. 
+# if len(solution_str) > _SOLUTION_CLIP_CHARS: +# solution_str = solution_str[-_SOLUTION_CLIP_CHARS:] + +# if method == "strict": +# # this also tests the formatting of the model +# solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str) +# if len(solutions) == 0: +# final_answer = None +# else: +# # take the last solution +# final_answer = solutions[-1].replace(",", "").replace("$", "") +# elif method == "flexible": +# answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) +# final_answer = None +# if len(answer) == 0: +# # no reward is there is no answer +# pass +# else: +# invalid_str = ["", "."] +# # find the last number that is not '.' +# for final_answer in reversed(answer): +# if final_answer not in invalid_str: +# break +# return final_answer + + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + + +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. + + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. + + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[: len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left) : -1] + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. 
+ + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +def is_correct_minerva( + solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" +) -> tuple[bool, str]: + """Check if the solution is correct according to Minerva criteria. + + Args: + solution_str: The solution string to check + gt: The ground truth answer + gt_need_extract: Whether the ground truth needs extraction + answer_pattern: Regex pattern to extract the answer + + Returns: + Tuple of (is_correct, normalized_prediction) + """ + # Extract answer from solution + match = re.findall(answer_pattern, solution_str) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + + # Process ground truth + # if gt_need_extract: + # gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) + # else: + assert not gt_need_extract + gt = normalize_final_answer(gt) + + return (pred == gt), pred + + +def is_correct_strict_box( + pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None +) -> tuple[int, Optional[str]]: + """Check if the prediction is correct using strict boxed answer criteria. + + Args: + pred: The prediction string + gt: The ground truth answer + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (score, extracted_prediction) + """ + # Extract the relevant part of the prediction + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100 :] + else: + pred = pred[-100:] + + # Extract and check the boxed answer + boxed_pred = last_boxed_only_string(pred) + extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + # print("==========", extracted_pred, gt) + + return 1 if (extracted_pred == gt) else -1, extracted_pred + + +def verify( + solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None +) -> tuple[bool, str | None]: + """Verify if the solution is correct. 
+ + Args: + solution_str: The solution string to verify + answer: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + True if the solution is correct, False otherwise + """ + if strict_box_verify: + correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return correct == 1, pred + + correct, pred = is_correct_minerva(solution_str, answer) + return correct, pred + + +def compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, +) -> float: + """Compute the reward score for a solution. + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + Reward score (1.0 for correct, -1.0 for incorrect) + """ + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + reward = 1.0 if correct else -1.0 + # acc = correct + + return reward + # return { + # "score": reward, + # "acc": acc, + # "pred": pred, + # } + + +def compute_reward(response, label, extra_info): + predict_str = response + # ground_truth = label + + reward = compute_score(response, label) + overlong_reward = 0 + if extra_info.get("enable_overlong_buffer", None): + overlong_buffer_len = extra_info["overlong_buffer_len"] + expected_len = extra_info["max_response_len"] - overlong_buffer_len + valid_response_length = len( + extra_info["tokenizer"](predict_str, return_tensors="pt")["input_ids"].flatten().tolist() + ) + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = extra_info["overlong_penalty_factor"] + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward + return reward + + +class DapoMathJudgerConfig(BaseModel): + extra_info: dict = Field(default={"score": 1, "format_score": 0}) + enable_overlong_buffer: bool + max_response_len: Optional[int] = None + overlong_buffer_len: Optional[int] = None + overlong_penalty_factor: Optional[float] = None + tokenizer: Any = None + + def __init__( + self, + enable_overlong_buffer: bool, + max_response_len: Optional[int], + overlong_buffer_len: Optional[int], + overlong_penalty_factor: Optional[float], + tokenizer: Any, + ): + # 初始化基类 + super().__init__( + enable_overlong_buffer=enable_overlong_buffer, + max_response_len=max_response_len, + overlong_buffer_len=overlong_buffer_len, + overlong_penalty_factor=overlong_penalty_factor, + tokenizer=tokenizer, + ) + + # 根据条件更新 extra_info + if enable_overlong_buffer: + assert max_response_len is not None + assert overlong_buffer_len is not None + assert overlong_penalty_factor is not None + assert tokenizer is not None + self.extra_info.update( + { + "enable_overlong_buffer": enable_overlong_buffer, + "max_response_len": max_response_len, + "overlong_buffer_len": overlong_buffer_len, + "overlong_penalty_factor": overlong_penalty_factor, + "tokenizer": tokenizer, + } + ) + + def build(self): + return NativeJudger(reward_func=compute_reward, extra_info=self.extra_info) + + +if __name__ == "__main__": + import json + + data = [] + with open( + 
"/cpfs01/user/lishuaibin/projects/202508/xtuner/work_dirs/dapo_math_qwen25-7B/20250828103011/t_1.jsonl", + encoding="utf-8", + ) as file: + for line in file: + line = line.strip() + if line: # Skip empty lines + try: + obj = json.loads(line) + data.append(obj) + except json.JSONDecodeError as e: + print(f"Error parsing line: {e}") + _data = data[0] + # print(_data) + prompt = _data["prompt"] + responses = _data["response"] + label = _data["label"] + for res in responses: + reward = compute_reward(res, label, {}) + # reward = compute_score(res, label, True) + + print(reward) diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index e12ee7954..935cde74c 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -1,4 +1,5 @@ import math +import random from typing import Literal, TypedDict import ray @@ -117,11 +118,12 @@ def _grouped_by_max_length(self, packed_data_batches): # 排序后这条 pack 会被放在最前面,导致 rank0 的第一个 step 消耗的有效 token 数往往少于其他 rank,是正常现象。 return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True) - def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int): - packed_data_batches = self._packing(data_batches, pack_max_length) + def _pack_and_pad(self, data_batches, pack_max_length: int): + data_batches_flatten = sum(data_batches, []) # type: ignore + random.shuffle(data_batches_flatten) + packed_data_batches = self._packing(data_batches_flatten, pack_max_length) packed_data_batches = self._grouped_by_max_length(packed_data_batches) - # todo: support round up num_packed_data_batches = len(packed_data_batches) data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote()) # type: ignore[attr-defined] dp_size = len(self.workers) // data_replicate_size @@ -129,13 +131,16 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: if pad_num > 0: # Reduce the attn calculation time by using multiple short sequence packs pad_tokens = tuple( - torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu") + torch.zeros(1, 1024, dtype=data_batches_flatten[0]["seq_ctx"].input_ids.dtype, device="cpu") for _ in range(pack_max_length // 1024) ) if pack_max_length % 1024 > 0: pad_tokens = pad_tokens + ( torch.zeros( - 1, pack_max_length % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu" + 1, + pack_max_length % 1024, + dtype=data_batches_flatten[0]["seq_ctx"].input_ids.dtype, + device="cpu", ), ) pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu") # type: ignore @@ -159,18 +164,35 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: } pad_data_samples = [pad_data for _ in range(pad_num)] packed_data_batches = packed_data_batches + pad_data_samples + return packed_data_batches - print(f"len(packed_data_batches): {len(packed_data_batches)}") + def fit(self, data_batches: list[list[ColateItem]], pack_max_length: int, optimizer_steps: int, rollout_idx: int): + optimizer_steps = min(optimizer_steps, len(data_batches)) + n_groups_per_step = math.ceil(len(data_batches) / optimizer_steps) + packed_data_batches_all_steps: list[list[ColateItem]] = [] + for i in range(optimizer_steps): + packed_data_batches_all_steps.append( + self._pack_and_pad( + data_batches[i * n_groups_per_step : (i + 1) * n_groups_per_step], + pack_max_length, + ) + ) handles = [] + data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote()) # type: ignore[attr-defined] + 
dp_size = len(self.workers) // data_replicate_size for worker_idx, worker in enumerate(self.workers): + packed_data_batches_all_steps_cur = [ + data[(worker_idx // data_replicate_size) :: dp_size] for data in packed_data_batches_all_steps + ] handles.append( worker.fit.remote( # type: ignore[attr-defined] - data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size], + data_batches=packed_data_batches_all_steps_cur, rollout_idx=rollout_idx, ) ) ray.get(handles) + return def offload(self, target: Literal["model", "optimizer", "all"] = "all"): if target == "model": diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index d345583ce..e70465a70 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -1,4 +1,3 @@ -import math import os import time from pathlib import Path @@ -239,17 +238,16 @@ def compute_ref_logprobs( self._ref_model.to_device("cpu") return loss_ctx_input_list - def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): - num_batches = len(data_batches) - iters_per_step = math.ceil(num_batches / self._optimizer_steps) - if num_batches < self._optimizer_steps: - logger.info( - f"Optimizer only step once because num_batches {num_batches} < optimizer_steps {self._optimizer_steps}." - ) + def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int): + iters_per_step = [len(data) for data in data_batches] + accum_iters = [0] + for iters in iters_per_step: + accum_iters.append(accum_iters[-1] + iters) + data_batches_flatten: list[WorkerInputItem] = sum(data_batches, []) # type: ignore seq_ctx_list: list[SequenceContext] = [] loss_ctx_input_list: list[RLLossContextInputItem] = [] - for data in data_batches: + for data in data_batches_flatten: seq_ctx = data["seq_ctx"].to(DEVICE) loss_ctx_input = RLLossContextInputItem( shifted_labels=data["shifted_labels"], @@ -261,7 +259,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): seq_ctx_list.append(seq_ctx) loss_ctx_input_list.append(loss_ctx_input) - del data_batches + del data_batches, data_batches_flatten rank_grad_tokens: torch.Tensor | None = None for loss_ctx_input in loss_ctx_input_list: @@ -303,9 +301,12 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): avg_kl_div = kl_div_sum / global_grad_tokens if global_grad_tokens > 0 else 0 logger.info(f"Rollout {rollout_idx}: avg KL divergence: {avg_kl_div:.4f}") - for i in range(0, len(seq_ctx_list), iters_per_step): - batches_seq_ctx = seq_ctx_list[i : i + iters_per_step] - batches_loss_ctx_input = loss_ctx_input_list[i : i + iters_per_step] + for step_idx in range(self._optimizer_steps): + batches_seq_ctx = seq_ctx_list[accum_iters[step_idx] : accum_iters[step_idx + 1]] + batches_loss_ctx_input = loss_ctx_input_list[accum_iters[step_idx] : accum_iters[step_idx + 1]] + # for i in range(0, len(seq_ctx_list), iters_per_step): + # batches_seq_ctx = seq_ctx_list[i : i + iters_per_step] + # batches_loss_ctx_input = loss_ctx_input_list[i : i + iters_per_step] loss_cfg = self.config.loss_cfg LossContext = loss_cfg.loss_ctx_cls @@ -336,7 +337,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}" for key, value in log_info.items() ) - log_str = f"Rollout {rollout_idx} Step {i}: " + log_str + log_str = f"Rollout {rollout_idx} Step {step_idx}: " + log_str logger.info(log_str) def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): diff --git 
a/xtuner/v1/rl/loss_fn.py b/xtuner/v1/rl/loss_fn.py
index 6efe90d9a..7df701811 100644
--- a/xtuner/v1/rl/loss_fn.py
+++ b/xtuner/v1/rl/loss_fn.py
@@ -54,12 +54,25 @@ def pg_loss_fn(
     check_config(["cliprange_low", "cliprange_high"], policy_loss_cfg)
     cliprange_low = policy_loss_cfg["cliprange_low"]
     cliprange_high = policy_loss_cfg["cliprange_high"]
-    ratio = (logprobs - old_logprobs.detach()).exp()
+    clip_ratio_c = policy_loss_cfg.get("clip_ratio_c", 10.0)
+
     advantages = advantages.to(logprobs.dtype)
-    loss1 = -ratio * advantages
-    loss2 = -ratio.clamp(1 - cliprange_low, 1 + cliprange_high) * advantages
-    loss_max = torch.max(loss1, loss2)
-    loss = (loss_max * loss_weights.to(loss_max.dtype)).sum()
+
+    negative_approx_kl = logprobs - old_logprobs
+    # Clamp negative_approx_kl for numerical stability before exponentiating
+    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+    ratio = torch.exp(negative_approx_kl)
+
+    pg_losses1 = -advantages * ratio
+    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)  # - clip(ratio, 1-cliprange, 1+cliprange) * A
+    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
+
+    pg_losses3 = -advantages * clip_ratio_c
+    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+
+    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+    loss = (pg_losses * loss_weights.to(pg_losses.dtype)).sum()
     return loss
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 08d7059b9..9a6ff997f 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -295,11 +295,16 @@ def fit(self):
             self.logger.info(f"rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}")
             ray.get(self._train_controller.onload.remote(target="all"))
             self.logger.info("Training controller loaded")
-            data_batches = self._prepare_train_data(data_groups, self._train_worker_cfg.pack_max_length)
-            self.logger.info(f"Prepared {len(data_batches)} training data batches")
+            data_batches, data_info = self._prepare_train_data(data_groups, self._train_worker_cfg.pack_max_length)
+            self.logger.info(f"Prepared {len(data_batches)} prompt groups ({len(data_batches) * len(data_batches[0])} responses) of training data")
+            self.logger.info(f"DataInfo {data_info}")
+
             ray.get(
                 self._train_controller.fit.remote(
-                    data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
+                    data_batches,
+                    pack_max_length=self._train_worker_cfg.pack_max_length,
+                    optimizer_steps=self._train_worker_cfg.optimizer_steps,
+                    rollout_idx=rollout_idx
                 )
             )
             ray.get(self._train_controller.offload.remote(target="optimizer"))
@@ -318,6 +323,12 @@
     # TODO: advantage 是在 DataFlow 里算好,还是在 train controller 里算?
     # 因为可能有根据 advantage 来判断数据能否进 rl 训练的需求。暂时先放在这
     def _prepare_train_data(self, data_groups, pack_max_length):
+        import numpy as np
+        rewards_list = []
+        advantages_list = []
+        prompt_len_list = []
+        response_len_list = []
+
         data_batches = []
         for group in data_groups:
             prompt = self.tokenizer.apply_chat_template(
@@ -325,29 +336,55 @@
             )
             prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].flatten().tolist()
             rewards = [data["reward"] for data in group]
+            rewards_list.extend(rewards)
             rewards = torch.tensor(rewards, dtype=torch.float32)
             advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)
             prompt_repeat_k = len(group)
+            data_batches_group = []
             for i in range(prompt_repeat_k):
                 item = group[i]["response_str"]
                 response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist()
                 input_ids = prompt_ids + response_ids
+                prompt_len_list.append(len(prompt_ids))
+                response_len_list.append(len(response_ids))
+
+                advantages_list.extend([advantages[i]] * len(response_ids))
+
                 shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100]
                 if len(input_ids) > pack_max_length:
                     input_ids = input_ids[:pack_max_length]
                     shifted_labels = shifted_labels[:pack_max_length]
                 input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0)
                 shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0)
-                data_batches.append(
+                data_batches_group.append(
                     dict(
                         seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"),
                         shifted_labels=shifted_labels,
                         advantage=advantages[i].item(),
                     )
                 )
-        random.shuffle(data_batches)
-        return data_batches
+            data_batches.append(data_batches_group)
+
+        advantages_list = np.array(advantages_list)
+        info_dict = {
+            "batch_size": len(rewards_list),
+            "rewards/mean": np.mean(rewards_list),
+            "rewards/min": np.min(rewards_list),
+            "rewards/max": np.max(rewards_list),
+            "advantages/mean": np.mean(advantages_list),
+            "advantages/min": np.min(advantages_list),
+            "advantages/max": np.max(advantages_list),
+            "response_len/mean": np.mean(response_len_list),
+            "response_len/min": np.min(response_len_list),
+            "response_len/max": np.max(response_len_list),
+            "response_len/std": np.std(response_len_list),
+            "prompt_len/mean": np.mean(prompt_len_list),
+            "prompt_len/min": np.min(prompt_len_list),
+            "prompt_len/max": np.max(prompt_len_list),
+        }
+        random.shuffle(data_batches)  # shuffle prompt groups; responses stay grouped by prompt
+        return data_batches, info_dict

     def _save_trajectories(self, data_groups, save_path):
         with open(save_path, "w") as f:
@@ -357,14 +394,14 @@ def _save_trajectories(self, data_groups, save_path):
             for data in group:
                 response_list.append(data["response_str"])
                 reward_list.append(data["reward"])
-            item = {
-                "messages": group[0]["messages"],
-                "response": response_list,
-                "label": group[0]["reward_model"]["ground_truth"],
-                "reward": reward_list,
-            }
-            json.dump(item, f)
-            f.write("\n")
+                item = {
+                    "messages": group[0]["messages"],
+                    "response": data["response_str"],
+                    "label": group[0]["reward_model"]["ground_truth"],
+                    "reward": data["reward"],
+                }
+                json.dump(item, f, ensure_ascii=False)
+                f.write("\n")

     def _load_trajectories(self, save_path):
         data_groups = []
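
Reviewer note (not part of the patch): a minimal standalone sketch of the dual-clip policy loss that the updated pg_loss_fn implements, for quick local sanity checks. The function name dual_clip_pg_loss and the toy tensors are illustrative; only torch is assumed.

import torch


def dual_clip_pg_loss(
    logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    loss_weights: torch.Tensor,
    cliprange_low: float = 0.2,
    cliprange_high: float = 0.28,
    clip_ratio_c: float = 10.0,
) -> torch.Tensor:
    # Importance ratio, clamped in log space for numerical stability.
    log_ratio = torch.clamp(logprobs - old_logprobs, min=-20.0, max=20.0)
    ratio = log_ratio.exp()

    # Standard PPO clipping with an asymmetric range (DAPO-style "clip higher").
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clipped = torch.maximum(pg_losses1, pg_losses2)

    # Dual-clip: for negative advantages, cap the loss at -clip_ratio_c * A so a
    # very large ratio cannot dominate the update.
    dual_clipped = torch.minimum(-advantages * clip_ratio_c, clipped)
    pg_losses = torch.where(advantages < 0, dual_clipped, clipped)

    # Token-level weighting with a sum reduction, matching the patched pg_loss_fn.
    return (pg_losses * loss_weights.to(pg_losses.dtype)).sum()


if __name__ == "__main__":
    torch.manual_seed(0)
    logprobs = torch.randn(8)
    old_logprobs = logprobs + 0.1 * torch.randn(8)
    advantages = torch.randn(8)
    weights = torch.full((8,), 1 / 8)
    print(dual_clip_pg_loss(logprobs, old_logprobs, advantages, weights))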
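
Reviewer note (not part of the patch): a sketch of the group-normalized advantages computed in _prepare_train_data and the soft overlong penalty applied by compute_reward in xtuner/v1/ray/judger/dapo_math.py. The helper names are illustrative; the default lengths roughly mirror the DAPO test script (max response 8192, buffer 4096, penalty factor 1.0).

import torch


def group_advantages(rewards: list[float], eps: float = 1e-8) -> torch.Tensor:
    # Normalize rewards within one prompt group: A_i = (r_i - mean) / (std + eps).
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean(0)) / (r.std(0) + eps)


def overlong_penalty(
    response_len: int,
    max_response_len: int = 8192,
    overlong_buffer_len: int = 4096,
    penalty_factor: float = 1.0,
) -> float:
    # Responses longer than (max_response_len - overlong_buffer_len) are penalized
    # linearly, reaching -penalty_factor at max_response_len; never a bonus.
    expected_len = max_response_len - overlong_buffer_len
    exceed = response_len - expected_len
    return min(-exceed / overlong_buffer_len * penalty_factor, 0.0)


if __name__ == "__main__":
    print(group_advantages([1.0, -1.0, -1.0, 1.0]))  # +1/-1 rewards -> roughly +/-0.87
    print(overlong_penalty(5000))                    # past the budget -> small negative penalty
    print(overlong_penalty(3000))                    # within the budget -> 0.0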
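
Reviewer note (not part of the patch): a tiny sketch of how the per-step batch counts become slice boundaries, mirroring the accum_iters prefix-sum logic added to the worker's fit(). step_slices is a hypothetical helper.

from itertools import accumulate


def step_slices(iters_per_step: list[int]) -> list[slice]:
    # Prefix sums [0, n0, n0+n1, ...] give one slice of the flattened batch list per optimizer step.
    bounds = [0, *accumulate(iters_per_step)]
    return [slice(a, b) for a, b in zip(bounds[:-1], bounds[1:])]


if __name__ == "__main__":
    print(step_slices([3, 2, 4]))  # [slice(0, 3), slice(3, 5), slice(5, 9)]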