From 1ee63ee0207e80ce84266adc9ff53c02a2644ca3 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 17 Sep 2025 14:56:56 +0800 Subject: [PATCH 01/22] add dapo test --- ci/scripts/test_dapo.sh | 35 +++ ci/scripts/test_dapo_trainer.py | 206 ++++++++++++++++ xtuner/v1/model/dense/qwen2.py | 53 ++++ xtuner/v1/ray/evaluator.py | 2 +- xtuner/v1/ray/judger/controller.py | 5 +- xtuner/v1/ray/judger/dapo_math.py | 373 +++++++++++++++++++++++++++++ xtuner/v1/train/rl_trainer.py | 10 +- 7 files changed, 676 insertions(+), 8 deletions(-) create mode 100644 ci/scripts/test_dapo.sh create mode 100644 ci/scripts/test_dapo_trainer.py create mode 100644 xtuner/v1/model/dense/qwen2.py create mode 100644 xtuner/v1/ray/judger/dapo_math.py diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh new file mode 100644 index 000000000..94f34da0f --- /dev/null +++ b/ci/scripts/test_dapo.sh @@ -0,0 +1,35 @@ +set -ex + +export XTUNER_USE_LMDEPLOY=1 +export XTUNER_USE_FA3=1 +export PYTHONPATH=/cpfs01/shared/llm_razor/huanghaian/code/lmdeploy/:$PYTHONPATH +export UVICORN_LOG_LEVEL="CRITICAl" +export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' + +# OUTPUT_DIR='work_dirs/dense_8b_gsm8k_grpo_fix_shuaibin' +OUTPUT_DIR='work_dirs/dapo_math_debug_8B' +OUTPUT_DIR='work_dirs/dapo_math_7B_newlmdeploy_nogroup' +if [ ! -d "$OUTPUT_DIR" ]; then + mkdir -p "$OUTPUT_DIR" +fi + +#export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_razor/huanghaian/code/verl/DAPO/DAPO-Qwen2.5-7b-MATH-0527a1/global_step_100/actor/huggingface" +export ROLLOUT_MODEL_PATH='/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B' +export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/dapo-math-17k.jsonl" + +python ci/scripts/test_dapo_trainer.py \ + --total-epochs 1 \ + --work-dir "$OUTPUT_DIR" \ + --num-workers 8 \ + --gpus-per-node 8 \ + --rollout-global-batch-size 512 \ + --train-optimizer-steps 16 \ + --max-concurrent 64 \ + --prompt-repeat-k 16 \ + --pack-max-length 32768 \ + --max-prompt-length 2048 \ + --max-response-length 8192 \ + --optimizer-disable-foreach \ + --enable-evaluate \ + --eval-data-path /cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/aime-2024.jsonl \ + 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" diff --git a/ci/scripts/test_dapo_trainer.py b/ci/scripts/test_dapo_trainer.py new file mode 100644 index 000000000..77ba1b6ea --- /dev/null +++ b/ci/scripts/test_dapo_trainer.py @@ -0,0 +1,206 @@ +import os +import re +from pathlib import Path +import ray +import argparse + +import matplotlib.pyplot as plt +import numpy as np +import torch.distributed as dist +from transformers import AutoTokenizer + +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig +from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig +from xtuner.v1.ray.accelerator import AcceleratorResourcesConfig +from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig +from xtuner.v1.ray.rollout import SampleParams +from xtuner.v1.ray.evaluator import EvaluatorConfig +from xtuner.v1.datasets import RLTextTokenizeFnConfig +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.ray.judger.controller import JudgerConfig +from xtuner.v1.rl.base import WorkerConfig +from xtuner.v1.rl.grpo import GRPOLossConfig +# from xtuner.v1.rl.grpo import GRPOLossConfig, WorkerConfig +# from 
xtuner.v1.rl.grpo.config import WorkerConfig, LossConfig +# from xtuner.v1.rl.grpo.trainer import Trainer +from xtuner.v1.train.rl_trainer import RLTrainer + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +os.environ['XTUNER_USE_FA3'] = "1" + + +def parse_args(): + parser = argparse.ArgumentParser(description="VLLM Rollout Test Script") + parser.add_argument("--total-epochs", type=int) + parser.add_argument("--work-dir", type=str, default="work_dir") + parser.add_argument("--model-path", type=str, default=MODEL_PATH) + parser.add_argument("--data-path", type=str, default=TRAIN_DATA_PATH) + parser.add_argument("--eval-data-path", type=str, default=TEST_DATA_PATH) + parser.add_argument("--num-workers", type=int, default=8) + parser.add_argument("--gpus-per-node", type=int, default=8) + parser.add_argument("--rollout-global-batch-size", type=int, default=128) + parser.add_argument("--train-optimizer-steps", type=int, default=1) + parser.add_argument("--max-concurrent", type=int, default=8) + parser.add_argument("--prompt-repeat-k", type=int, default=8) + parser.add_argument("--pack-max-length", type=int, default=8192) + parser.add_argument("--max-prompt-length", type=int, default=512) + parser.add_argument("--max-response-length", type=int, default=1024) + parser.add_argument("--optimizer-disable-foreach", action="store_true") # save memory usage during opt.step() + parser.add_argument("--policy-loss-type", type=str, default="vanilla") + parser.add_argument("--enable-evaluate", action="store_true") + parser.add_argument("--evaluate-step", type=int, default=1) + parser.add_argument("--evaluate-ratio", type=float, default=1) + parser.add_argument("--ray-cluster-url", type=str, default="") + return parser.parse_args() + + +def main(args): + if args.ray_cluster_url == "": + ray.init(num_cpus=128, ignore_reinit_error=True) + else: + ray.init(address=args.ray_cluster_url, ignore_reinit_error=True) + load_from = args.model_path + resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_accelerators_per_worker=1, + num_cpus_per_worker=12, + num_workers=args.num_workers, + cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB + ) + rollout_config = RolloutConfig( + env="test_env", + model_path=args.model_path, + model_name=os.path.basename(args.model_path).lower(), + tokenizer_path=args.model_path, + rollout_cross_node_comm=False, + tensor_parallel_size=2, + expert_parallel_size=1, + gpus_per_node=args.gpus_per_node, # gpu: 8, npu: 16 + dtype="bfloat16", + skip_load_weights=False, + ) + dataflow_config = DataFlowConfig( + env="test", + max_concurrent=args.max_concurrent, + prompt_repeat_k=args.prompt_repeat_k, + global_batch_size=args.rollout_global_batch_size, + sample_params=SampleParams( + max_tokens=args.max_response_length, + # ###### greedy + # top_k=20, + # # temperature=1e-6, + ########## + top_k=0, + top_p=1.0, + temperature=1.0, + + min_tokens=0, + # stop_token_ids= [], + # logprobs= 0, + # skip_special_tokens= True, + do_sample=True, + ), + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig + dapomath_judger_config = DapoMathJudgerConfig(True, args.max_response_length, 4096, 1.0, tokenizer) + judger_cfg = JudgerConfig( + reward_judger_configs={"math_dapo": dapomath_judger_config} + ) + train_dataset_cfg = [ + { + "dataset": DatasetConfig(name="dapo_math", + anno_path=args.data_path, + 
sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length), + }, + ] + eval_dataset_cfg = [ + { + "dataset": DatasetConfig(name="dapo_math", + anno_path=args.eval_data_path, + sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=args.max_prompt_length), + }, + ] + dataloader_cfg = DataloaderConfig( + pack_max_length=args.pack_max_length, + collator='fake_collator', + pack_level='none', + ) + # tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + evaluator_cfg = EvaluatorConfig( + dataset_cfg=eval_dataset_cfg, + tokenizer=tokenizer, + max_concurrent=args.max_concurrent, + eval_sample_ratio=args.evaluate_ratio, + evaluate_step=args.evaluate_step, + compute_metric_func=None + ) + replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=train_dataset_cfg, + dataloader_cfg=dataloader_cfg, + tokenizer=tokenizer, + postprocessor=None + ) + train_worker_cfg: WorkerConfig = WorkerConfig( + # model_cfg=Qwen3Dense8BConfig(), + model_cfg=Qwen2Dense7BConfig(), + optim_cfg=AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, + foreach=False if args.optimizer_disable_foreach else None), + loss_cfg=GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=args.policy_loss_type, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode="chunk", + chunk_size=512), + lr_cfg=LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6), + fsdp_cfg=FSDPConfig( + torch_compile=False, + cpu_offload=False, + ep_size=1, + ), + load_from=args.model_path, + sp_size=1, + optimizer_steps=args.train_optimizer_steps, + pack_max_length=args.pack_max_length, + ) + trainer = RLTrainer( + load_from=load_from, + resources=resources, + rollout_config=rollout_config, + dataflow_config=dataflow_config, + judger_config=judger_cfg, + replay_buffer_config=replay_buffer_cfg, + evaluator_config=evaluator_cfg, + train_worker_cfg=train_worker_cfg, + tokenizer_path=args.model_path, + work_dir=args.work_dir, + total_epochs=args.total_epochs, + enable_evaluate=args.enable_evaluate + ) + trainer.fit() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/xtuner/v1/model/dense/qwen2.py b/xtuner/v1/model/dense/qwen2.py new file mode 100644 index 000000000..ecc947afb --- /dev/null +++ b/xtuner/v1/model/dense/qwen2.py @@ -0,0 +1,53 @@ +import re + +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.module.attention import MHAConfig + +from .dense import Dense + + +class Qwen2Dense(Dense): + def to_hf_key_list(self, key: str) -> list[str]: + if self.config.tie_word_embeddings and "lm_head" in key: + key = key.replace("lm_head", "embed_tokens") + + if "layers" in key or "embed_tokens" in key: + key = "model." 
+ key + + if "layers" in key: + key = re.sub(r"layers\.(\d+)\.(experts|gate)", r"layers.\1.mlp.\2", key) + + if key.startswith("norm."): + return [key.replace("norm.", "model.norm.")] + else: + return [key] + + +class Qwen2DenseConfig(TransformerConfig): + use_sliding_window: bool = False + + def build(self) -> Qwen2Dense: + return Qwen2Dense(self) + + +# TODO: Unify the config name style +class Qwen2Dense7BConfig(Qwen2DenseConfig): + vocab_size: int = 152064 + max_position_embeddings: int = 32768 + pad_token_id: int = 151643 # eos_id + eos_token_id: int = 151643 # eos_id + num_hidden_layers: int = 28 + hidden_size: int = 3584 + intermediate_size: int = 18944 + rms_norm_eps: float = 1e-06 + rope_theta: float = 10000 + hidden_act: str = "silu" + attention: MHAConfig = MHAConfig( + num_attention_heads=28, + num_key_value_heads=4, + head_dim=128, + qk_norm=False, + qkv_bias=True, + ) + # sliding_window= 4096 + tie_word_embeddings: bool = False \ No newline at end of file diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py index 1130ef49d..180486e26 100644 --- a/xtuner/v1/ray/evaluator.py +++ b/xtuner/v1/ray/evaluator.py @@ -140,7 +140,7 @@ def default_compute_metric(self, samples): Returns: dict: A dictionary containing the accuracy score. """ - return {"accuracy": sum(s["reward"] > 0 for s in samples) / len(samples)} + return {"accuracy": sum(s["acc"] > 0 for s in samples) / len(samples)} async def eval_worker_task(self, sample: RLTextDataItem): """A single worker task to evaluate one sample. diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py index 074e007cd..ab27a9cee 100644 --- a/xtuner/v1/ray/judger/controller.py +++ b/xtuner/v1/ray/judger/controller.py @@ -163,14 +163,17 @@ async def run( num_samples = len(group_data_item) final_rewards = [0.0] * num_samples + acc_list = [] for i in range(num_samples): for name, scores in rewards_by_name.items(): weight = data_source.get(name, 1.0) - final_rewards[i] += scores[i] * weight + final_rewards[i] += scores[i]['score'] * weight + acc_list.append(scores[i]['acc']) assert len(final_rewards) == num_samples for i, item in enumerate(group_data_item): item["reward"] = final_rewards[i] + item["acc"] = acc_list[i] if not input_list: return group_data_item[0] return group_data_item diff --git a/xtuner/v1/ray/judger/dapo_math.py b/xtuner/v1/ray/judger/dapo_math.py new file mode 100644 index 000000000..785ad9d05 --- /dev/null +++ b/xtuner/v1/ray/judger/dapo_math.py @@ -0,0 +1,373 @@ +import re + +from pydantic import BaseModel, Field +from typing import Any, Optional + +from .native import NativeJudger + +# _SOLUTION_CLIP_CHARS = 300 + + +# def extract_solution(solution_str, method="strict"): +# assert method in ["strict", "flexible"] + +# # Optimization: Regular expression matching on very long strings can be slow. +# # For math problems, the final answer is usually at the end. +# # We only match on the last 300 characters, which is a safe approximation for 300 tokens. 
+# if len(solution_str) > _SOLUTION_CLIP_CHARS: +# solution_str = solution_str[-_SOLUTION_CLIP_CHARS:] + +# if method == "strict": +# # this also tests the formatting of the model +# solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str) +# if len(solutions) == 0: +# final_answer = None +# else: +# # take the last solution +# final_answer = solutions[-1].replace(",", "").replace("$", "") +# elif method == "flexible": +# answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) +# final_answer = None +# if len(answer) == 0: +# # no reward is there is no answer +# pass +# else: +# invalid_str = ["", "."] +# # find the last number that is not '.' +# for final_answer in reversed(answer): +# if final_answer not in invalid_str: +# break +# return final_answer + + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +import re +from typing import Optional + + +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. + + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + return string[idx: right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. + + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[: len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left): -1] + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. 
+ + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +def is_correct_minerva( + solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" +) -> tuple[bool, str]: + """Check if the solution is correct according to Minerva criteria. + + Args: + solution_str: The solution string to check + gt: The ground truth answer + gt_need_extract: Whether the ground truth needs extraction + answer_pattern: Regex pattern to extract the answer + + Returns: + Tuple of (is_correct, normalized_prediction) + """ + # Extract answer from solution + match = re.findall(answer_pattern, solution_str) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + + # Process ground truth + if gt_need_extract: + gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) + else: + gt = normalize_final_answer(gt) + + return (pred == gt), pred + + +def is_correct_strict_box( + pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None +) -> tuple[int, Optional[str]]: + """Check if the prediction is correct using strict boxed answer criteria. + + Args: + pred: The prediction string + gt: The ground truth answer + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (score, extracted_prediction) + """ + # Extract the relevant part of the prediction + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100:] + else: + pred = pred[-100:] + + # Extract and check the boxed answer + boxed_pred = last_boxed_only_string(pred) + extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + # print("==========", extracted_pred, gt) + + return 1 if (extracted_pred == gt) else -1, extracted_pred + + +def verify( + solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None +) -> bool: + """Verify if the solution is correct. 
+ + Args: + solution_str: The solution string to verify + answer: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + True if the solution is correct, False otherwise + """ + if strict_box_verify: + correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return correct == 1, pred + + correct, pred = is_correct_minerva(solution_str, answer) + return correct, pred + + +def compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, +) -> dict: + """Compute the reward score for a solution. + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + Reward score (1.0 for correct, -1.0 for incorrect) + """ + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + reward = 1.0 if correct else -1.0 + acc = correct + + return { + "score": reward, + "acc": acc + } + + +def compute_reward(response, label, extra_info): + predict_str = response + + eos_token = extra_info['eos_token'] + if response.endswith(eos_token): + response = response[: -len(eos_token)] + + out = compute_score(response, label) + reward = out['score'] + + overlong_reward = 0 + if extra_info.get("enable_overlong_buffer", None): + overlong_buffer_len = extra_info['overlong_buffer_len'] + expected_len = extra_info['max_response_len'] - overlong_buffer_len + valid_response_length = len(extra_info['tokenizer'](predict_str, return_tensors="pt")["input_ids"].flatten().tolist()) + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = extra_info['overlong_penalty_factor'] + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward + return {'score': reward, 'acc': out['acc']} + + +class DapoMathJudgerConfig(BaseModel): + extra_info: dict = Field(default={"score": 1, "format_score": 0, "eos_token": '<|endoftext|>'}) + enable_overlong_buffer: bool + max_response_len: Optional[int] = None + overlong_buffer_len: Optional[int] = None + overlong_penalty_factor: Optional[float] = None + tokenizer: Any = None + + def __init__(self, enable_overlong_buffer: bool, max_response_len: Optional[int], + overlong_buffer_len: Optional[int], overlong_penalty_factor: Optional[float], tokenizer: Any): + # 初始化基类 + super().__init__( + enable_overlong_buffer=enable_overlong_buffer, + max_response_len=max_response_len, + overlong_buffer_len=overlong_buffer_len, + overlong_penalty_factor=overlong_penalty_factor, + tokenizer=tokenizer + ) + + # 根据条件更新 extra_info + if enable_overlong_buffer: + assert max_response_len is not None + assert overlong_buffer_len is not None + assert overlong_penalty_factor is not None + assert tokenizer is not None + self.extra_info.update({ + "enable_overlong_buffer": enable_overlong_buffer, + "max_response_len": max_response_len, + "overlong_buffer_len": overlong_buffer_len, + "overlong_penalty_factor": overlong_penalty_factor, + "tokenizer": tokenizer, + }) + + def build(self): + return NativeJudger(reward_func=compute_reward, extra_info=self.extra_info) diff --git a/xtuner/v1/train/rl_trainer.py 
b/xtuner/v1/train/rl_trainer.py index 08d7059b9..42361384f 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -218,17 +218,15 @@ def __init__( if self._enable_evaluate and evaluator_config: self._evaluator = Evaluator.remote(evaluator_config, self._rollout_env_controller) # type: ignore[attr-defined] self._evaluator_sample_params = SampleParams( - top_p=1.0, - temperature=0.0, - do_sample=False, + top_p=0.7, + temperature=1.0, + do_sample=True, max_tokens=dataflow_config.sample_params.max_tokens, top_k=1, ) self._eval_step = evaluator_config.evaluate_step else: - self._evaluator = None - self._evaluator_sample_params = SampleParams() - self._eval_step = 0 + pass self._global_batch_size = dataflow_config.global_batch_size self._rollout_steps = ( From 3dfb684d9bda79ea1c37bb2f0f583832262f1de2 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 17 Sep 2025 15:23:26 +0800 Subject: [PATCH 02/22] update --- ci/scripts/test_dapo.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh index 94f34da0f..d47393f9d 100644 --- a/ci/scripts/test_dapo.sh +++ b/ci/scripts/test_dapo.sh @@ -14,8 +14,9 @@ if [ ! -d "$OUTPUT_DIR" ]; then fi #export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_razor/huanghaian/code/verl/DAPO/DAPO-Qwen2.5-7b-MATH-0527a1/global_step_100/actor/huggingface" -export ROLLOUT_MODEL_PATH='/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B' + export ROLLOUT_MODEL_PATH='/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B' export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/dapo-math-17k.jsonl" +export ROLLOUT_TEST_DATA_PATH="/cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/aime-2024.jsonl" python ci/scripts/test_dapo_trainer.py \ --total-epochs 1 \ @@ -31,5 +32,5 @@ python ci/scripts/test_dapo_trainer.py \ --max-response-length 8192 \ --optimizer-disable-foreach \ --enable-evaluate \ - --eval-data-path /cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/aime-2024.jsonl \ + --evaluate-step 5 \ 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" From d558d971f5f8f791069f7ac5753a4ace15edb28b Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 18 Sep 2025 10:52:40 +0800 Subject: [PATCH 03/22] fix ratio --- xtuner/v1/module/lm_head/lm_head.py | 4 +++- xtuner/v1/train/cli/grpo.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/xtuner/v1/module/lm_head/lm_head.py b/xtuner/v1/module/lm_head/lm_head.py index aa1adbacb..44b21c74d 100644 --- a/xtuner/v1/module/lm_head/lm_head.py +++ b/xtuner/v1/module/lm_head/lm_head.py @@ -33,7 +33,9 @@ def forward( # type: ignore[override] b = self.bias if loss_ctx is None: logits = F.linear(hidden_states, w, b) - return None, logits + # Note: the loss calculation will convert logits to float32, so for alignment, + # we also need to convert it to float32 here to prevent the ratio from being 1 during rl training + return None, logits.float() else: return loss_ctx.forward(hidden_states, w, b) diff --git a/xtuner/v1/train/cli/grpo.py b/xtuner/v1/train/cli/grpo.py index 41c6ec2b8..a8f178af4 100644 --- a/xtuner/v1/train/cli/grpo.py +++ b/xtuner/v1/train/cli/grpo.py @@ -3,7 +3,6 @@ import ray -from transformers import AutoTokenizer from xtuner.v1.config import ( AdamWConfig, FSDPConfig, @@ -102,7 +101,7 @@ def main(args): collator="fake_collator", pack_level="none", ) - tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + if eval_dataset_cfg: evaluator_cfg = EvaluatorConfig( 
dataset_cfg=eval_dataset_cfg, From 64d5e29bac744e986d7f2efbe835e1868f5b5bef Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 18 Sep 2025 14:41:16 +0800 Subject: [PATCH 04/22] debug --- xtuner/v1/engine/train_engine.py | 44 ++++++++++++++++++++------------ xtuner/v1/loss/base_loss_ctx.py | 12 ++++----- xtuner/v1/loss/chunk_loss.py | 6 +++-- xtuner/v1/model/base.py | 1 + xtuner/v1/model/dense/dense.py | 3 ++- xtuner/v1/rl/base/controller.py | 3 +++ xtuner/v1/rl/base/worker.py | 8 ++++++ xtuner/v1/rl/grpo/loss.py | 3 ++- 8 files changed, 53 insertions(+), 27 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 49540aa75..876002605 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -32,7 +32,7 @@ from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig from xtuner.v1.module.router import NoAuxRouterConfig from xtuner.v1.utils import get_device, get_logger, get_torch_device_module - +from torch.distributed.nn.functional import all_reduce logger = get_logger() DEVICE = get_device() @@ -50,7 +50,7 @@ def profile_time_and_memory(desc): max_memory = torch_device.max_memory_allocated() cost_time = time.time() - start_t - logger.success(f"{desc} Elapsed time {cost_time:.2f} seconds, peak gpu memory {max_memory / 1024**3:.1f}G") + logger.success(f"{desc} Elapsed time {cost_time:.2f} seconds, peak gpu memory {max_memory / 1024 ** 3:.1f}G") threading_lock = threading.Lock() @@ -138,11 +138,11 @@ class TrainEngine: float8_handler: Optional[Float8Handler] def __init__( - self, - model_cfg: TransformerConfig, - optim_cfg: OptimConfig, - fsdp_cfg: FSDPConfig, - intra_layer_micro_batch: int = 1, + self, + model_cfg: TransformerConfig, + optim_cfg: OptimConfig, + fsdp_cfg: FSDPConfig, + intra_layer_micro_batch: int = 1, ) -> None: self.model_cfg = model_cfg self.optim_cfg = optim_cfg @@ -224,8 +224,8 @@ def train_step(self, data_batches: list[ModelItem]): iters_per_step = self.grad_accumulation_steps(len(data_batches)) moe_need_update_bias = ( - isinstance(getattr(self.model_cfg, "router", None), NoAuxRouterConfig) - and self.model_cfg.router.router_bias_update_speed > 0 + isinstance(getattr(self.model_cfg, "router", None), NoAuxRouterConfig) + and self.model_cfg.router.router_bias_update_speed > 0 ) if moe_need_update_bias: tokens_per_expert_global_for_bias = torch.zeros( @@ -245,8 +245,10 @@ def train_step(self, data_batches: list[ModelItem]): logger.info(f"grad_accumulation_steps: {iters_per_step}") self._count += 1 + grad_acc_loss = [] + max_ratio = [] for i in range(0, len(data_batches), intra_layer_micro_batch): - data_batch = data_batches[i : i + intra_layer_micro_batch] + data_batch = data_batches[i: i + intra_layer_micro_batch] seq_ctx_list = [] loss_ctx_list = [] for data in data_batch: @@ -268,9 +270,14 @@ def train_step(self, data_batches: list[ModelItem]): # llm loss has been global averaged llm_loss = output["loss"] - step_llm_loss += llm_loss.detach().clone() + grad_acc_loss.append(llm_loss.detach().clone().item()) + step_loss += llm_loss.detach().clone() + + if dist.is_initialized(): + llm_loss = all_reduce(llm_loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) loss = llm_loss + step_llm_loss += llm_loss.detach().clone() if "balancing_loss" in output: balancing_loss = output["balancing_loss"] / iters_per_step @@ -296,6 +303,7 @@ def train_step(self, data_batches: list[ModelItem]): del output loss.backward() step_loss += loss.detach().clone() + 
grad_acc_loss.append(loss.detach().clone().item()) if moe_need_update_bias: avg_count_load = tokens_per_expert_global_for_bias.float().mean(1) @@ -319,6 +327,8 @@ def train_step(self, data_batches: list[ModelItem]): dist.all_reduce(reduced_z_loss.div_(dist.get_world_size())) loss_log["reduced_z_loss"] = reduced_z_loss.item() other_log["consumed_tokens"] = step_consumed_tokens.item() + other_log["grad_acc_loss"] = grad_acc_loss + other_log["max_ratio"] = max_ratio return loss_log, other_log def from_hf(self, hf_path: str | Path, strict: bool = False): @@ -349,7 +359,7 @@ def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach device = tensors[0].device norms: Tuple[DTensor, ...] if (foreach is None and _has_foreach_support(tensors, device)) or ( # type: ignore - foreach and _device_has_foreach_support(device) + foreach and _device_has_foreach_support(device) ): norms = torch._foreach_norm(tensors, norm_type) # type: ignore elif foreach: @@ -361,7 +371,7 @@ def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach torch.stack([norm.to_local() for norm in norms]), norm_type, dtype=torch.float32 ) if norm_type == 2: - local_norm_squared = local_norm**2 + local_norm_squared = local_norm ** 2 for i, placement in enumerate(placements): if isinstance(placement, Shard): # When using ep + fsdp, the placement corresponding to fsdp mesh is _StridedShard @@ -371,7 +381,7 @@ def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach pass else: raise ValueError(f"Unsupported placement type {placement} in clip_grad_norm") - global_norm = local_norm_squared**0.5 + global_norm = local_norm_squared ** 0.5 else: raise NotImplementedError return global_norm @@ -426,9 +436,9 @@ def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): # TODO: Support async save def save_dcp( - self, - model_dir: Path, - optimizer_dir: Path | None = None, + self, + model_dir: Path, + optimizer_dir: Path | None = None, ): rank = dist.get_rank() diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py index d9bd8ddbf..f34159ce9 100644 --- a/xtuner/v1/loss/base_loss_ctx.py +++ b/xtuner/v1/loss/base_loss_ctx.py @@ -129,8 +129,8 @@ def chunk_mode( assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) - loss = ChunkLoss.apply(hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size) - return loss, None + loss, max_ratio = ChunkLoss.apply(hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size) + return loss, None, max_ratio def forward( self, @@ -145,9 +145,9 @@ def forward( if self.loss_cfg.mode == "eager": loss, logits = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) else: - loss, logits = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) + loss, logits, max_ratio = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) # Step 2.c in the loss calculation - if dist.is_initialized(): - loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) - return loss, logits + # if dist.is_initialized(): + # loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) + return loss, logits, max_ratio diff --git a/xtuner/v1/loss/chunk_loss.py b/xtuner/v1/loss/chunk_loss.py index bf86c6276..d9bc948df 100644 --- a/xtuner/v1/loss/chunk_loss.py +++ b/xtuner/v1/loss/chunk_loss.py @@ -25,20 +25,22 @@ def forward( 
grad_inputs_chunks = torch.split(grad_inputs, chunk_size, dim=1) hidden_states_chunks = torch.split(hidden_states, chunk_size, dim=1) + max_ratio = [] for i in range(len(hidden_states_chunks)): hidden_states_chunk = hidden_states_chunks[i] grad_inputs_chunk = grad_inputs_chunks[i] - (chunk_grad_input, chunk_grad_weight), (chunk_loss, _) = torch.func.grad_and_value( + (chunk_grad_input, chunk_grad_weight), (chunk_loss, _, ratio) = torch.func.grad_and_value( loss_forward, argnums=(0, 1), has_aux=True )(hidden_states_chunk, head_weight, None, loss_kwargs_chunks[i]) accumulated_loss.add_(chunk_loss) grad_inputs_chunk.copy_(chunk_grad_input) grad_weight.add_(chunk_grad_weight) + max_ratio.append(ratio) ctx.save_for_backward(grad_inputs, grad_weight) - return accumulated_loss + return accumulated_loss, torch.stack(max_ratio).max() @staticmethod def backward(ctx, *grad_output): diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 24acab5dc..dd3dc21d9 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -139,6 +139,7 @@ class ModelOutputs(TypedDict): hidden_states: NotRequired[list[torch.Tensor]] logits: NotRequired[torch.Tensor] loss: torch.Tensor + max_ratio: torch.Tensor def _is_float8_available(): diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 9ae7009e0..d4bd028fd 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -89,9 +89,10 @@ def forward( hidden_states = self.norm(hidden_states) - loss, logits = self.lm_head(hidden_states, loss_ctx) # type: ignore + loss, logits, max_ratio = self.lm_head(hidden_states, loss_ctx) # type: ignore output["loss"] = loss output["logits"] = logits + output["max_ratio"] = max_ratio return ModelOutputs(**output) # type: ignore[typeddict-item] def build_embeddings(self, config: TransformerConfig): diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index e12ee7954..afe3bb069 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -163,7 +163,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: print(f"len(packed_data_batches): {len(packed_data_batches)}") handles = [] + index = list(range(len(packed_data_batches))) for worker_idx, worker in enumerate(self.workers): + _index = index[(worker_idx // data_replicate_size):: dp_size] + print(f"worker_idx: {worker_idx}, index: {_index}") handles.append( worker.fit.remote( # type: ignore[attr-defined] data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size], diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index d345583ce..ef6d3d577 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -327,6 +327,13 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): data_batches=engine_input, ) grad_norm = self._engine.clip_grad_norm() + + if i == 0 and grad_norm > 100: + logger.info(f"{loss_log['total_loss'], other_log['grad_acc_loss'], other_log['max_ratio']}") + torch.save(engine_input, f'./temp/engine_input_rank_{self.rank}.pth') + self.save_hf(f'./temp/model') + raise RuntimeError('DEBUG 退出') + self._engine.step_optimizer(grad_norm) log_info = dict() log_info.update(loss_log) @@ -339,6 +346,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): log_str = f"Rollout {rollout_idx} Step {i}: " + log_str logger.info(log_str) + def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): self._engine.save_hf(hf_dir, save_dtype) 
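# ---------------------------------------------------------------------------
# Note on the max_ratio diagnostic introduced in this patch: the GRPO loss now
# returns the largest importance-sampling ratio per chunk, ChunkLoss and the
# train engine collect it into other_log["max_ratio"], and worker.fit dumps the
# engine inputs when the gradient norm spikes. A minimal, self-contained sketch
# of that bookkeeping is shown below; the tensor names and clip values are
# illustrative assumptions, not the exact xtuner signatures.
import torch


def grpo_token_loss(
    logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    clip_low: float = 0.2,
    clip_high: float = 0.28,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Clipped policy-gradient loss plus the max importance-sampling ratio.

    All tensors are per-token and share the same shape; old_logprobs comes
    from the rollout policy and is treated as a constant.
    """
    ratio = (logprobs - old_logprobs.detach()).exp()  # pi_theta / pi_rollout
    clipped = torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
    loss = -torch.minimum(ratio * advantages, clipped * advantages).mean()
    # ratio.max() is returned purely as a diagnostic: values far from 1.0 on
    # the first optimizer step usually indicate a rollout/train mismatch.
    return loss, ratio.max().detach()
# ---------------------------------------------------------------------------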
diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/grpo/loss.py index 4fd50c1a3..ee0d39c91 100644 --- a/xtuner/v1/rl/grpo/loss.py +++ b/xtuner/v1/rl/grpo/loss.py @@ -144,6 +144,7 @@ def loss_fn( policy_loss_weight, self.loss_cfg.policy_loss_cfg, ) + ratio = (logprobs - old_logprobs.detach()).exp() if self.loss_cfg.use_kl_loss: ref_logprobs = loss_kwargs.ref_logprobs @@ -154,4 +155,4 @@ def loss_fn( kl_loss = kl_penalty(logprobs, ref_logprobs, kl_loss_weight, self.loss_cfg.kl_loss_type) loss = loss + kl_loss - return loss, logits + return loss, logits, ratio.max() From 2d7d7281921ab374df656fd24f130cc35d8444a6 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 18 Sep 2025 15:22:03 +0800 Subject: [PATCH 05/22] debug --- xtuner/v1/engine/train_engine.py | 1 + xtuner/v1/loss/chunk_loss.py | 2 +- xtuner/v1/model/dense/dense.py | 3 ++- xtuner/v1/module/lm_head/lm_head.py | 2 +- xtuner/v1/rl/grpo/loss.py | 2 +- 5 files changed, 6 insertions(+), 4 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 876002605..f53802bf6 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -270,6 +270,7 @@ def train_step(self, data_batches: list[ModelItem]): # llm loss has been global averaged llm_loss = output["loss"] + max_ratio.append(output["max_ratio"].item()) grad_acc_loss.append(llm_loss.detach().clone().item()) step_loss += llm_loss.detach().clone() diff --git a/xtuner/v1/loss/chunk_loss.py b/xtuner/v1/loss/chunk_loss.py index d9bc948df..f899a93b0 100644 --- a/xtuner/v1/loss/chunk_loss.py +++ b/xtuner/v1/loss/chunk_loss.py @@ -30,7 +30,7 @@ def forward( hidden_states_chunk = hidden_states_chunks[i] grad_inputs_chunk = grad_inputs_chunks[i] - (chunk_grad_input, chunk_grad_weight), (chunk_loss, _, ratio) = torch.func.grad_and_value( + (chunk_grad_input, chunk_grad_weight), (chunk_loss, (_, ratio)) = torch.func.grad_and_value( loss_forward, argnums=(0, 1), has_aux=True )(hidden_states_chunk, head_weight, None, loss_kwargs_chunks[i]) diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index d4bd028fd..22a5d7114 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -92,7 +92,8 @@ def forward( loss, logits, max_ratio = self.lm_head(hidden_states, loss_ctx) # type: ignore output["loss"] = loss output["logits"] = logits - output["max_ratio"] = max_ratio + if max_ratio is not None: + output["max_ratio"] = max_ratio return ModelOutputs(**output) # type: ignore[typeddict-item] def build_embeddings(self, config: TransformerConfig): diff --git a/xtuner/v1/module/lm_head/lm_head.py b/xtuner/v1/module/lm_head/lm_head.py index 44b21c74d..2daf85379 100644 --- a/xtuner/v1/module/lm_head/lm_head.py +++ b/xtuner/v1/module/lm_head/lm_head.py @@ -35,7 +35,7 @@ def forward( # type: ignore[override] logits = F.linear(hidden_states, w, b) # Note: the loss calculation will convert logits to float32, so for alignment, # we also need to convert it to float32 here to prevent the ratio from being 1 during rl training - return None, logits.float() + return None, logits.float(), None else: return loss_ctx.forward(hidden_states, w, b) diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/grpo/loss.py index ee0d39c91..34dc9602a 100644 --- a/xtuner/v1/rl/grpo/loss.py +++ b/xtuner/v1/rl/grpo/loss.py @@ -155,4 +155,4 @@ def loss_fn( kl_loss = kl_penalty(logprobs, ref_logprobs, kl_loss_weight, self.loss_cfg.kl_loss_type) loss = loss + kl_loss - return loss, logits, ratio.max() + return loss, (logits, 
ratio.max()) From 081a3810539fbc41dd154d80082c84e44500140c Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 18 Sep 2025 16:43:18 +0800 Subject: [PATCH 06/22] fix --- xtuner/v1/engine/train_engine.py | 2 -- xtuner/v1/rl/base/controller.py | 3 --- xtuner/v1/rl/base/worker.py | 2 +- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index f53802bf6..f6b2615e9 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -303,8 +303,6 @@ def train_step(self, data_batches: list[ModelItem]): del output loss.backward() - step_loss += loss.detach().clone() - grad_acc_loss.append(loss.detach().clone().item()) if moe_need_update_bias: avg_count_load = tokens_per_expert_global_for_bias.float().mean(1) diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index afe3bb069..e12ee7954 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -163,10 +163,7 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: print(f"len(packed_data_batches): {len(packed_data_batches)}") handles = [] - index = list(range(len(packed_data_batches))) for worker_idx, worker in enumerate(self.workers): - _index = index[(worker_idx // data_replicate_size):: dp_size] - print(f"worker_idx: {worker_idx}, index: {_index}") handles.append( worker.fit.remote( # type: ignore[attr-defined] data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size], diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index ef6d3d577..b9ad192d7 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -329,7 +329,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): grad_norm = self._engine.clip_grad_norm() if i == 0 and grad_norm > 100: - logger.info(f"{loss_log['total_loss'], other_log['grad_acc_loss'], other_log['max_ratio']}") + logger.info(f"{loss_log['total_loss'], other_log['grad_acc_loss'], other_log['max_ratio'], grad_norm}") torch.save(engine_input, f'./temp/engine_input_rank_{self.rank}.pth') self.save_hf(f'./temp/model') raise RuntimeError('DEBUG 退出') From 4ab52172caf195d07c6f9342abc5e7cbb112f661 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 23 Sep 2025 17:47:50 +0800 Subject: [PATCH 07/22] update --- ci/scripts/test_dapo.sh | 11 ++--- xtuner/v1/model/__init__.py | 3 ++ xtuner/v1/model/dense/qwen2.py | 71 ++++++++++++++++++++++++++++++++- xtuner/v1/rl/base/controller.py | 2 +- xtuner/v1/rl/base/worker.py | 10 ++--- xtuner/v1/train/rl_trainer.py | 22 +++++----- 6 files changed, 91 insertions(+), 28 deletions(-) diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh index d47393f9d..de7c07ca1 100644 --- a/ci/scripts/test_dapo.sh +++ b/ci/scripts/test_dapo.sh @@ -2,21 +2,18 @@ set -ex export XTUNER_USE_LMDEPLOY=1 export XTUNER_USE_FA3=1 -export PYTHONPATH=/cpfs01/shared/llm_razor/huanghaian/code/lmdeploy/:$PYTHONPATH +export PYTHONPATH=/mnt/shared-storage-user/huanghaian/code/lmdeploy/:$PYTHONPATH export UVICORN_LOG_LEVEL="CRITICAl" export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' -# OUTPUT_DIR='work_dirs/dense_8b_gsm8k_grpo_fix_shuaibin' -OUTPUT_DIR='work_dirs/dapo_math_debug_8B' OUTPUT_DIR='work_dirs/dapo_math_7B_newlmdeploy_nogroup' if [ ! 
-d "$OUTPUT_DIR" ]; then mkdir -p "$OUTPUT_DIR" fi -#export ROLLOUT_MODEL_PATH="/cpfs01/shared/llm_razor/huanghaian/code/verl/DAPO/DAPO-Qwen2.5-7b-MATH-0527a1/global_step_100/actor/huggingface" - export ROLLOUT_MODEL_PATH='/cpfs01/shared/llm_ddd/lishuaibin/ckpt/Qwen/Qwen2.5-Math-7B' -export ROLLOUT_DATA_PATH="/cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/dapo-math-17k.jsonl" -export ROLLOUT_TEST_DATA_PATH="/cpfs01/shared/llm_razor/lishuaibin/math_dapo_data/aime-2024.jsonl" +export ROLLOUT_MODEL_PATH='/mnt/shared-storage-user/llmrazor-share/model/Qwen2.5-Math-7B' +export ROLLOUT_DATA_PATH="/mnt/shared-storage-user/huanghaian/code/verl/data/dapo_math/dapo-math-17k.jsonl" +export ROLLOUT_TEST_DATA_PATH="/mnt/shared-storage-user/huanghaian/code/verl/data/dapo_math/aime-2024.jsonl" python ci/scripts/test_dapo_trainer.py \ --total-epochs 1 \ diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index 25c1ecbc2..014868152 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -10,6 +10,7 @@ from .compose.internvl import InternVL3P5Dense8BConfig, InternVL3P5MoE30BA3Config, InternVLBaseConfig from .dense.dense import Dense from .dense.qwen3 import Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig +from .dense.qwen2 import Qwen2Dense7BConfig, Qwen2DenseConfig from .moe.deepseek_v3 import DeepSeekV3Config from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig @@ -42,6 +43,8 @@ def get_model_config_from_hf(model_path: Path): return Qwen3MoEConfig.from_hf(model_path) elif cfg.model_type == "qwen3": return Qwen3DenseConfig.from_hf(model_path) + elif cfg.model_type == "qwen2": + return Qwen2DenseConfig.from_hf(model_path) elif cfg.model_type == "gpt_oss": return GptOssConfig.from_hf(model_path) elif cfg.model_type == "deepseek_v3": diff --git a/xtuner/v1/model/dense/qwen2.py b/xtuner/v1/model/dense/qwen2.py index ecc947afb..dc6a4bf90 100644 --- a/xtuner/v1/model/dense/qwen2.py +++ b/xtuner/v1/model/dense/qwen2.py @@ -2,6 +2,10 @@ from xtuner.v1.model.base import TransformerConfig from xtuner.v1.module.attention import MHAConfig +from transformers.models.qwen2 import Qwen2Config as HFQwen2DenseConfig +from pathlib import Path +import torch +from typing_extensions import Self from .dense import Dense @@ -29,6 +33,69 @@ class Qwen2DenseConfig(TransformerConfig): def build(self) -> Qwen2Dense: return Qwen2Dense(self) + @classmethod + def from_hf(cls, hf_path: str | Path) -> Self: + from transformers import AutoConfig + from transformers.models.qwen2 import Qwen2Config as HFConfig + + hf_config = AutoConfig.from_pretrained(hf_path, trust_remote_code=True) + + assert isinstance(hf_config, HFConfig) + + config = cls( + hf_config=hf_config, + vocab_size=hf_config.vocab_size, + max_position_embeddings=hf_config.max_position_embeddings, + pad_token_id=hf_config.eos_token_id, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + num_hidden_layers=hf_config.num_hidden_layers, + max_window_layers=hf_config.max_window_layers, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + rms_norm_eps=hf_config.rms_norm_eps, + rope_theta=hf_config.rope_theta, + hidden_act=hf_config.hidden_act, + attention=MHAConfig( + num_attention_heads=hf_config.num_attention_heads, + num_key_value_heads=hf_config.num_key_value_heads, + head_dim=128, + sliding_window=hf_config.sliding_window, + qk_norm=False, + qkv_bias=True, + ), + 
use_sliding_window=hf_config.use_sliding_window, + tie_word_embeddings=hf_config.tie_word_embeddings, + ) + + return config + + @property + def hf_config(self) -> HFQwen2DenseConfig: + """Check if the configuration can be saved in HuggingFace format.""" + return HFQwen2DenseConfig( + architectures=["Qwen2ForCausalLM"], + vocab_size=self.vocab_size, + max_position_embeddings=self.max_position_embeddings, + max_window_layers=self.max_window_layers, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + rms_norm_eps=self.rms_norm_eps, + rope_theta=self.rope_theta, + hidden_act=self.hidden_act, + num_attention_heads=self.attention.num_attention_heads, + num_key_value_heads=self.attention.num_key_value_heads, + head_dim=self.attention.head_dim, + sliding_window=self.attention.sliding_window, + use_sliding_window=self.use_sliding_window, + tie_word_embeddings=self.tie_word_embeddings, + dtype=torch.bfloat16, + ) + # TODO: Unify the config name style class Qwen2Dense7BConfig(Qwen2DenseConfig): @@ -48,6 +115,6 @@ class Qwen2Dense7BConfig(Qwen2DenseConfig): head_dim=128, qk_norm=False, qkv_bias=True, - ) + ) # sliding_window= 4096 - tie_word_embeddings: bool = False \ No newline at end of file + tie_word_embeddings: bool = False diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index e12ee7954..5c0f3e11b 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -119,7 +119,7 @@ def _grouped_by_max_length(self, packed_data_batches): def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int): packed_data_batches = self._packing(data_batches, pack_max_length) - packed_data_batches = self._grouped_by_max_length(packed_data_batches) + # packed_data_batches = self._grouped_by_max_length(packed_data_batches) # todo: support round up num_packed_data_batches = len(packed_data_batches) diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index b9ad192d7..f808406df 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -328,11 +328,11 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): ) grad_norm = self._engine.clip_grad_norm() - if i == 0 and grad_norm > 100: - logger.info(f"{loss_log['total_loss'], other_log['grad_acc_loss'], other_log['max_ratio'], grad_norm}") - torch.save(engine_input, f'./temp/engine_input_rank_{self.rank}.pth') - self.save_hf(f'./temp/model') - raise RuntimeError('DEBUG 退出') + # if i == 0 and grad_norm > 100: + # logger.info(f"{loss_log['total_loss'], other_log['grad_acc_loss'], other_log['max_ratio'], grad_norm}") + # torch.save(engine_input, f'./temp/engine_input_rank_{self.rank}.pth') + # self.save_hf(f'./temp/model') + # raise RuntimeError('DEBUG 退出') self._engine.step_optimizer(grad_norm) log_info = dict() diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 42361384f..336c59f5c 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -348,21 +348,17 @@ def _prepare_train_data(self, data_groups, pack_max_length): return data_batches def _save_trajectories(self, data_groups, save_path): - with open(save_path, "w") as f: + with open(save_path, "w", encoding='utf-8') as f: for group in data_groups: - response_list = [] - reward_list = [] for data in group: - response_list.append(data["response_str"]) - 
reward_list.append(data["reward"]) - item = { - "messages": group[0]["messages"], - "response": response_list, - "label": group[0]["reward_model"]["ground_truth"], - "reward": reward_list, - } - json.dump(item, f) - f.write("\n") + item = { + "messages": data["messages"], + "response": data["response_str"], + "label": data["reward_model"]["ground_truth"], + "reward": data["reward"], + } + json.dump(item, f, ensure_ascii=False, indent=2) + f.write("\n") def _load_trajectories(self, save_path): data_groups = [] From 444aae2144a06926394276f404e6df04c98592c3 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 23 Sep 2025 18:11:47 +0800 Subject: [PATCH 08/22] update1 --- ci/scripts/test_dapo.sh | 2 ++ xtuner/v1/ray/evaluator.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh index de7c07ca1..c446a6090 100644 --- a/ci/scripts/test_dapo.sh +++ b/ci/scripts/test_dapo.sh @@ -15,6 +15,8 @@ export ROLLOUT_MODEL_PATH='/mnt/shared-storage-user/llmrazor-share/model/Qwen2.5 export ROLLOUT_DATA_PATH="/mnt/shared-storage-user/huanghaian/code/verl/data/dapo_math/dapo-math-17k.jsonl" export ROLLOUT_TEST_DATA_PATH="/mnt/shared-storage-user/huanghaian/code/verl/data/dapo_math/aime-2024.jsonl" +ray stop --force + python ci/scripts/test_dapo_trainer.py \ --total-epochs 1 \ --work-dir "$OUTPUT_DIR" \ diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py index 180486e26..f011d8fd2 100644 --- a/xtuner/v1/ray/evaluator.py +++ b/xtuner/v1/ray/evaluator.py @@ -240,8 +240,8 @@ async def run(self, sample_params: Optional[SampleParams] = None, return_samples self.dataloader = iter(self.dataset) self.sample_params = sample_params if sample_params else SampleParams() # set greedy sample for evaluator - self.sample_params.temperature = 0.0 - self.sample_params.top_k = 1 + # self.sample_params.temperature = 0.0 + # self.sample_params.top_k = 1 ray.get(self.env_controller.restart.remote()) # type: ignore[attr-defined] await self.concurrent_eval_task_runner() scores = self.compute_metric(self.return_list) From e726881f6316ee1002619d3ca904a4f1808924fb Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 23 Sep 2025 18:29:41 +0800 Subject: [PATCH 09/22] update1 --- xtuner/v1/train/rl_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 336c59f5c..266d786c6 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -351,9 +351,11 @@ def _save_trajectories(self, data_groups, save_path): with open(save_path, "w", encoding='utf-8') as f: for group in data_groups: for data in group: + response_ids = self.tokenizer.encode(data["response_str"], add_special_tokens=False) item = { "messages": data["messages"], "response": data["response_str"], + "response_len": len(response_ids), "label": data["reward_model"]["ground_truth"], "reward": data["reward"], } From 9ace2b920aefbba62234cc62c238129c2369f81d Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 23 Sep 2025 22:09:37 +0800 Subject: [PATCH 10/22] fix bos --- xtuner/v1/model/dense/qwen2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xtuner/v1/model/dense/qwen2.py b/xtuner/v1/model/dense/qwen2.py index dc6a4bf90..1ee3c4a47 100644 --- a/xtuner/v1/model/dense/qwen2.py +++ b/xtuner/v1/model/dense/qwen2.py @@ -101,6 +101,7 @@ def hf_config(self) -> HFQwen2DenseConfig: class Qwen2Dense7BConfig(Qwen2DenseConfig): vocab_size: int = 152064 max_position_embeddings: int = 32768 + bos_token_id: int = 
151643 pad_token_id: int = 151643 # eos_id eos_token_id: int = 151643 # eos_id num_hidden_layers: int = 28 From 6ea5b8237e27b1f40f69842b2d9cd6b72e9d7b63 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 24 Sep 2025 10:37:25 +0800 Subject: [PATCH 11/22] refine --- ci/scripts/test_dapo.sh | 1 + ci/scripts/test_dapo_trainer.py | 4 +- xtuner/v1/engine/train_engine.py | 33 ++++++------- xtuner/v1/loss/base_loss_ctx.py | 6 +-- xtuner/v1/model/__init__.py | 2 +- xtuner/v1/model/dense/qwen2.py | 9 ++-- xtuner/v1/ray/judger/controller.py | 4 +- xtuner/v1/ray/judger/dapo_math.py | 77 ++++++++++++++++-------------- xtuner/v1/rl/base/worker.py | 1 - xtuner/v1/rl/loss_fn.py | 27 +++++++---- xtuner/v1/train/rl_trainer.py | 66 +++++++++++++++++++++++-- 11 files changed, 153 insertions(+), 77 deletions(-) diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh index c446a6090..6e0c5ac54 100644 --- a/ci/scripts/test_dapo.sh +++ b/ci/scripts/test_dapo.sh @@ -32,4 +32,5 @@ python ci/scripts/test_dapo_trainer.py \ --optimizer-disable-foreach \ --enable-evaluate \ --evaluate-step 5 \ + --hf-interval 50 \ 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" diff --git a/ci/scripts/test_dapo_trainer.py b/ci/scripts/test_dapo_trainer.py index 77ba1b6ea..25dce262c 100644 --- a/ci/scripts/test_dapo_trainer.py +++ b/ci/scripts/test_dapo_trainer.py @@ -63,6 +63,7 @@ def parse_args(): parser.add_argument("--enable-evaluate", action="store_true") parser.add_argument("--evaluate-step", type=int, default=1) parser.add_argument("--evaluate-ratio", type=float, default=1) + parser.add_argument("--hf-interval", type=float, default=50) parser.add_argument("--ray-cluster-url", type=str, default="") return parser.parse_args() @@ -196,7 +197,8 @@ def main(args): tokenizer_path=args.model_path, work_dir=args.work_dir, total_epochs=args.total_epochs, - enable_evaluate=args.enable_evaluate + enable_evaluate=args.enable_evaluate, + hf_interval=args.hf_interval ) trainer.fit() diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index f6b2615e9..09c5234d1 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -20,6 +20,7 @@ set_optimizer_state_dict, ) from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.nn.functional import all_reduce from torch.distributed.tensor.placement_types import Placement from torch.utils._foreach_utils import ( _device_has_foreach_support, @@ -32,7 +33,7 @@ from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig from xtuner.v1.module.router import NoAuxRouterConfig from xtuner.v1.utils import get_device, get_logger, get_torch_device_module -from torch.distributed.nn.functional import all_reduce + logger = get_logger() DEVICE = get_device() @@ -50,7 +51,7 @@ def profile_time_and_memory(desc): max_memory = torch_device.max_memory_allocated() cost_time = time.time() - start_t - logger.success(f"{desc} Elapsed time {cost_time:.2f} seconds, peak gpu memory {max_memory / 1024 ** 3:.1f}G") + logger.success(f"{desc} Elapsed time {cost_time:.2f} seconds, peak gpu memory {max_memory / 1024**3:.1f}G") threading_lock = threading.Lock() @@ -138,11 +139,11 @@ class TrainEngine: float8_handler: Optional[Float8Handler] def __init__( - self, - model_cfg: TransformerConfig, - optim_cfg: OptimConfig, - fsdp_cfg: FSDPConfig, - intra_layer_micro_batch: int = 1, + self, + model_cfg: TransformerConfig, + optim_cfg: OptimConfig, + fsdp_cfg: FSDPConfig, + intra_layer_micro_batch: int = 1, ) -> None: self.model_cfg = model_cfg 
self.optim_cfg = optim_cfg @@ -224,8 +225,8 @@ def train_step(self, data_batches: list[ModelItem]): iters_per_step = self.grad_accumulation_steps(len(data_batches)) moe_need_update_bias = ( - isinstance(getattr(self.model_cfg, "router", None), NoAuxRouterConfig) - and self.model_cfg.router.router_bias_update_speed > 0 + isinstance(getattr(self.model_cfg, "router", None), NoAuxRouterConfig) + and self.model_cfg.router.router_bias_update_speed > 0 ) if moe_need_update_bias: tokens_per_expert_global_for_bias = torch.zeros( @@ -248,7 +249,7 @@ def train_step(self, data_batches: list[ModelItem]): grad_acc_loss = [] max_ratio = [] for i in range(0, len(data_batches), intra_layer_micro_batch): - data_batch = data_batches[i: i + intra_layer_micro_batch] + data_batch = data_batches[i : i + intra_layer_micro_batch] seq_ctx_list = [] loss_ctx_list = [] for data in data_batch: @@ -358,7 +359,7 @@ def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach device = tensors[0].device norms: Tuple[DTensor, ...] if (foreach is None and _has_foreach_support(tensors, device)) or ( # type: ignore - foreach and _device_has_foreach_support(device) + foreach and _device_has_foreach_support(device) ): norms = torch._foreach_norm(tensors, norm_type) # type: ignore elif foreach: @@ -370,7 +371,7 @@ def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach torch.stack([norm.to_local() for norm in norms]), norm_type, dtype=torch.float32 ) if norm_type == 2: - local_norm_squared = local_norm ** 2 + local_norm_squared = local_norm**2 for i, placement in enumerate(placements): if isinstance(placement, Shard): # When using ep + fsdp, the placement corresponding to fsdp mesh is _StridedShard @@ -380,7 +381,7 @@ def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach pass else: raise ValueError(f"Unsupported placement type {placement} in clip_grad_norm") - global_norm = local_norm_squared ** 0.5 + global_norm = local_norm_squared**0.5 else: raise NotImplementedError return global_norm @@ -435,9 +436,9 @@ def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): # TODO: Support async save def save_dcp( - self, - model_dir: Path, - optimizer_dir: Path | None = None, + self, + model_dir: Path, + optimizer_dir: Path | None = None, ): rank = dist.get_rank() diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py index f34159ce9..0d96cbfe8 100644 --- a/xtuner/v1/loss/base_loss_ctx.py +++ b/xtuner/v1/loss/base_loss_ctx.py @@ -3,12 +3,10 @@ from typing import Annotated, Generic, Literal, TypeVar import torch -import torch.distributed as dist import torch.nn as nn from cyclopts import Parameter from pydantic import BaseModel, ConfigDict from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.nn.functional import all_reduce from .chunk_loss import ChunkLoss @@ -129,7 +127,9 @@ def chunk_mode( assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) - loss, max_ratio = ChunkLoss.apply(hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size) + loss, max_ratio = ChunkLoss.apply( + hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size + ) return loss, None, max_ratio def forward( diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index 014868152..940f3df53 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -9,8 +9,8 @@ from 
.compose.intern_s1 import InternS1BaseConfig, InternS1Config, InternS1MiniConfig from .compose.internvl import InternVL3P5Dense8BConfig, InternVL3P5MoE30BA3Config, InternVLBaseConfig from .dense.dense import Dense -from .dense.qwen3 import Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig from .dense.qwen2 import Qwen2Dense7BConfig, Qwen2DenseConfig +from .dense.qwen3 import Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig from .moe.deepseek_v3 import DeepSeekV3Config from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig diff --git a/xtuner/v1/model/dense/qwen2.py b/xtuner/v1/model/dense/qwen2.py index 1ee3c4a47..b67c80aae 100644 --- a/xtuner/v1/model/dense/qwen2.py +++ b/xtuner/v1/model/dense/qwen2.py @@ -1,12 +1,13 @@ import re - -from xtuner.v1.model.base import TransformerConfig -from xtuner.v1.module.attention import MHAConfig -from transformers.models.qwen2 import Qwen2Config as HFQwen2DenseConfig from pathlib import Path + import torch from typing_extensions import Self +from transformers.models.qwen2 import Qwen2Config as HFQwen2DenseConfig +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.module.attention import MHAConfig + from .dense import Dense diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py index ab27a9cee..59e4f7931 100644 --- a/xtuner/v1/ray/judger/controller.py +++ b/xtuner/v1/ray/judger/controller.py @@ -167,8 +167,8 @@ async def run( for i in range(num_samples): for name, scores in rewards_by_name.items(): weight = data_source.get(name, 1.0) - final_rewards[i] += scores[i]['score'] * weight - acc_list.append(scores[i]['acc']) + final_rewards[i] += scores[i]["score"] * weight + acc_list.append(scores[i]["acc"]) assert len(final_rewards) == num_samples for i, item in enumerate(group_data_item): diff --git a/xtuner/v1/ray/judger/dapo_math.py b/xtuner/v1/ray/judger/dapo_math.py index 785ad9d05..b1cdc4b3a 100644 --- a/xtuner/v1/ray/judger/dapo_math.py +++ b/xtuner/v1/ray/judger/dapo_math.py @@ -1,10 +1,11 @@ import re +from typing import Any, Optional from pydantic import BaseModel, Field -from typing import Any, Optional from .native import NativeJudger + # _SOLUTION_CLIP_CHARS = 300 @@ -55,9 +56,6 @@ # limitations under the License. # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py -import re -from typing import Optional - def last_boxed_only_string(string: str) -> Optional[str]: """Extract the last LaTeX boxed expression from a string. 
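For quick orientation, this is how the boxed-answer helpers in this judger module are typically chained (a minimal sketch, assuming the functions are importable from xtuner.v1.ray.judger.dapo_math; the solution string is made up):

from xtuner.v1.ray.judger.dapo_math import last_boxed_only_string, remove_boxed

solution = r"Adding both cases gives $\boxed{\frac{1}{2}}$, so the probability is one half."
boxed = last_boxed_only_string(solution)   # "\boxed{\frac{1}{2}}"
answer = remove_boxed(boxed)               # "\frac{1}{2}"
print(answer)
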
@@ -86,7 +84,7 @@ def last_boxed_only_string(string: str) -> Optional[str]: break i += 1 - return string[idx: right_brace_idx + 1] if right_brace_idx is not None else None + return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None def remove_boxed(s: str) -> str: @@ -101,7 +99,7 @@ def remove_boxed(s: str) -> str: left = "\\boxed{" assert s[: len(left)] == left, f"box error: {s}" assert s[-1] == "}", f"box error: {s}" - return s[len(left): -1] + return s[len(left) : -1] # Constants for normalization @@ -205,7 +203,7 @@ def normalize_final_answer(final_answer: str) -> str: def is_correct_minerva( - solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" + solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" ) -> tuple[bool, str]: """Check if the solution is correct according to Minerva criteria. @@ -233,7 +231,7 @@ def is_correct_minerva( def is_correct_strict_box( - pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None + pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None ) -> tuple[int, Optional[str]]: """Check if the prediction is correct using strict boxed answer criteria. @@ -248,7 +246,7 @@ def is_correct_strict_box( # Extract the relevant part of the prediction if pause_tokens_index is not None: assert len(pause_tokens_index) == 4 - pred = pred[pause_tokens_index[-1] - 100:] + pred = pred[pause_tokens_index[-1] - 100 :] else: pred = pred[-100:] @@ -261,7 +259,7 @@ def is_correct_strict_box( def verify( - solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None + solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None ) -> bool: """Verify if the solution is correct. @@ -283,10 +281,10 @@ def verify( def compute_score( - solution_str: str, - ground_truth: str, - strict_box_verify: bool = False, - pause_tokens_index: Optional[list[int]] = None, + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, ) -> dict: """Compute the reward score for a solution. 
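The compute_reward hunk just below layers DAPO's overlong-buffer shaping on top of the ±1 correctness score. A small numeric sketch of that penalty, with hypothetical values and the same variable names as the code:

max_response_len = 8192          # sampler budget
overlong_buffer_len = 4096       # soft buffer at the end of the budget
overlong_penalty_factor = 1.0

expected_len = max_response_len - overlong_buffer_len   # 4096
valid_response_length = 6144                             # tokens actually generated
exceed_len = valid_response_length - expected_len        # 2048
overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)  # -0.5

reward = 1.0 + overlong_reward   # a correct but long answer ends up at 0.5
# responses shorter than expected_len give exceed_len <= 0, so the penalty clamps to 0
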
@@ -308,51 +306,56 @@ def compute_score( reward = 1.0 if correct else -1.0 acc = correct - return { - "score": reward, - "acc": acc - } + return {"score": reward, "acc": acc} def compute_reward(response, label, extra_info): predict_str = response - eos_token = extra_info['eos_token'] + eos_token = extra_info["eos_token"] if response.endswith(eos_token): response = response[: -len(eos_token)] out = compute_score(response, label) - reward = out['score'] + reward = out["score"] overlong_reward = 0 if extra_info.get("enable_overlong_buffer", None): - overlong_buffer_len = extra_info['overlong_buffer_len'] - expected_len = extra_info['max_response_len'] - overlong_buffer_len - valid_response_length = len(extra_info['tokenizer'](predict_str, return_tensors="pt")["input_ids"].flatten().tolist()) + overlong_buffer_len = extra_info["overlong_buffer_len"] + expected_len = extra_info["max_response_len"] - overlong_buffer_len + valid_response_length = len( + extra_info["tokenizer"](predict_str, return_tensors="pt")["input_ids"].flatten().tolist() + ) exceed_len = valid_response_length - expected_len - overlong_penalty_factor = extra_info['overlong_penalty_factor'] + overlong_penalty_factor = extra_info["overlong_penalty_factor"] overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) reward += overlong_reward - return {'score': reward, 'acc': out['acc']} + return {"score": reward, "acc": out["acc"]} class DapoMathJudgerConfig(BaseModel): - extra_info: dict = Field(default={"score": 1, "format_score": 0, "eos_token": '<|endoftext|>'}) + extra_info: dict = Field(default={"score": 1, "format_score": 0, "eos_token": "<|endoftext|>"}) enable_overlong_buffer: bool max_response_len: Optional[int] = None overlong_buffer_len: Optional[int] = None overlong_penalty_factor: Optional[float] = None tokenizer: Any = None - def __init__(self, enable_overlong_buffer: bool, max_response_len: Optional[int], - overlong_buffer_len: Optional[int], overlong_penalty_factor: Optional[float], tokenizer: Any): + def __init__( + self, + enable_overlong_buffer: bool, + max_response_len: Optional[int], + overlong_buffer_len: Optional[int], + overlong_penalty_factor: Optional[float], + tokenizer: Any, + ): # 初始化基类 super().__init__( enable_overlong_buffer=enable_overlong_buffer, max_response_len=max_response_len, overlong_buffer_len=overlong_buffer_len, overlong_penalty_factor=overlong_penalty_factor, - tokenizer=tokenizer + tokenizer=tokenizer, ) # 根据条件更新 extra_info @@ -361,13 +364,15 @@ def __init__(self, enable_overlong_buffer: bool, max_response_len: Optional[int] assert overlong_buffer_len is not None assert overlong_penalty_factor is not None assert tokenizer is not None - self.extra_info.update({ - "enable_overlong_buffer": enable_overlong_buffer, - "max_response_len": max_response_len, - "overlong_buffer_len": overlong_buffer_len, - "overlong_penalty_factor": overlong_penalty_factor, - "tokenizer": tokenizer, - }) + self.extra_info.update( + { + "enable_overlong_buffer": enable_overlong_buffer, + "max_response_len": max_response_len, + "overlong_buffer_len": overlong_buffer_len, + "overlong_penalty_factor": overlong_penalty_factor, + "tokenizer": tokenizer, + } + ) def build(self): return NativeJudger(reward_func=compute_reward, extra_info=self.extra_info) diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index f808406df..c9c855a20 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -346,7 +346,6 @@ def fit(self, data_batches: list[WorkerInputItem], 
rollout_idx: int): log_str = f"Rollout {rollout_idx} Step {i}: " + log_str logger.info(log_str) - def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): self._engine.save_hf(hf_dir, save_dtype) diff --git a/xtuner/v1/rl/loss_fn.py b/xtuner/v1/rl/loss_fn.py index 6efe90d9a..8121b408d 100644 --- a/xtuner/v1/rl/loss_fn.py +++ b/xtuner/v1/rl/loss_fn.py @@ -45,8 +45,8 @@ def check_config(keys_needed: list[str], config: dict[str, Any]) -> None: @register_policy_loss("vanilla") def pg_loss_fn( - logprobs: torch.Tensor, - old_logprobs: torch.Tensor, + log_prob: torch.Tensor, + old_log_prob: torch.Tensor, advantages: torch.Tensor, loss_weights: torch.Tensor, policy_loss_cfg: dict, @@ -54,12 +54,23 @@ def pg_loss_fn( check_config(["cliprange_low", "cliprange_high"], policy_loss_cfg) cliprange_low = policy_loss_cfg["cliprange_low"] cliprange_high = policy_loss_cfg["cliprange_high"] - ratio = (logprobs - old_logprobs.detach()).exp() - advantages = advantages.to(logprobs.dtype) - loss1 = -ratio * advantages - loss2 = -ratio.clamp(1 - cliprange_low, 1 + cliprange_high) * advantages - loss_max = torch.max(loss1, loss2) - loss = (loss_max * loss_weights.to(loss_max.dtype)).sum() + clip_ratio_c = 10.0 + advantages = advantages.to(log_prob.dtype) + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + pg_losses1 = -advantages * ratio + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + loss = (pg_losses * loss_weights.to(pg_losses.dtype)).sum() return loss diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 266d786c6..b07486c66 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -7,6 +7,7 @@ from shutil import rmtree from typing import cast +import numpy as np import ray import torch from mmengine import load @@ -293,8 +294,9 @@ def fit(self): self.logger.info(f"rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") ray.get(self._train_controller.onload.remote(target="all")) self.logger.info("Training controller loaded") - data_batches = self._prepare_train_data(data_groups, self._train_worker_cfg.pack_max_length) + data_batches, data_info = self._prepare_train_data(data_groups, self._train_worker_cfg.pack_max_length) self.logger.info(f"Prepared {len(data_batches)} training data batches") + self.logger.info(f"DataInfo {data_info}") ray.get( self._train_controller.fit.remote( data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx @@ -316,6 +318,11 @@ def fit(self): # TODO: advantage 是在 DataFlow 里算好,还是在 train controller 里算? 
# 因为可能有根据 advantage 来判断数据能否进 rl 训练的需求。暂时先放在这 def _prepare_train_data(self, data_groups, pack_max_length): + rewards_list = [] + advantages_list = [] + prompt_len_list = [] + response_len_list = [] + data_batches = [] for group in data_groups: prompt = self.tokenizer.apply_chat_template( @@ -323,6 +330,7 @@ def _prepare_train_data(self, data_groups, pack_max_length): ) prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].flatten().tolist() rewards = [data["reward"] for data in group] + rewards_list.extend(rewards) rewards = torch.tensor(rewards, dtype=torch.float32) advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) @@ -331,6 +339,11 @@ def _prepare_train_data(self, data_groups, pack_max_length): item = group[i]["response_str"] response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() input_ids = prompt_ids + response_ids + + prompt_len_list.append(len(prompt_ids)) + response_len_list.append(len(response_ids)) + advantages_list.extend([advantages[i]] * len(response_ids)) + shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] if len(input_ids) > pack_max_length: input_ids = input_ids[:pack_max_length] @@ -345,22 +358,65 @@ def _prepare_train_data(self, data_groups, pack_max_length): ) ) random.shuffle(data_batches) - return data_batches + + advantages_list = np.array(advantages_list) + info_dict = { + "batch_size": len(rewards_list), + "rewards/mean": np.mean(rewards_list), + "rewards/min": np.min(rewards_list), + "rewards/max": np.max(rewards_list), + "advantages/mean": np.mean(advantages_list), + "advantages/min": np.min(advantages_list), + "advantages/max": np.max(advantages_list), + "response_len/mean": np.mean(response_len_list), + "response_len/min": np.min(response_len_list), + "response_len/max": np.max(response_len_list), + "response_len/std": np.std(response_len_list), + "prompt_len/mean": np.mean(prompt_len_list), + "prompt_len/min": np.min(prompt_len_list), + "prompt_len/max": np.max(prompt_len_list), + } + return data_batches, info_dict def _save_trajectories(self, data_groups, save_path): - with open(save_path, "w", encoding='utf-8') as f: + rewards = [] + response_len_list = [] + for group in data_groups: + for data in group: + rewards.append(data["reward"]) + response_ids = self.tokenizer.encode(data["response_str"], add_special_tokens=False) + response_len_list.append(len(response_ids)) + + rewards = torch.tensor(rewards) + response_lens = torch.tensor(response_len_list) + + _count = 0 + with open(save_path, "w", encoding="utf-8") as f: + item = { + "reward_mean": rewards.mean().item(), + "reward_std": rewards.std().item(), + "reward_max": rewards.max().item(), + "reward_min": rewards.min().item(), + "response_len_mean": response_lens.mean().item(), + "response_len_std": response_lens.std().item(), + "response_len_max": response_lens.max().item(), + "response_len_min": response_lens.min().item(), + "total_len": len(rewards), + } + json.dump(item, f, ensure_ascii=False, indent=2) + f.write("\n") for group in data_groups: for data in group: - response_ids = self.tokenizer.encode(data["response_str"], add_special_tokens=False) item = { "messages": data["messages"], "response": data["response_str"], - "response_len": len(response_ids), + "response_len": response_len_list[_count], "label": data["reward_model"]["ground_truth"], "reward": data["reward"], } json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") + _count += 1 def _load_trajectories(self, save_path): data_groups = [] From 
6c1d2458ed8f6600e8ed765e404c9e5455cf8769 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 24 Sep 2025 10:56:16 +0800 Subject: [PATCH 12/22] fix --- xtuner/v1/train/rl_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index b07486c66..4b5b61840 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -387,8 +387,8 @@ def _save_trajectories(self, data_groups, save_path): response_ids = self.tokenizer.encode(data["response_str"], add_special_tokens=False) response_len_list.append(len(response_ids)) - rewards = torch.tensor(rewards) - response_lens = torch.tensor(response_len_list) + rewards = torch.tensor(rewards).float() + response_lens = torch.tensor(response_len_list).float() _count = 0 with open(save_path, "w", encoding="utf-8") as f: @@ -452,7 +452,7 @@ def _maybe_save_hf(self): "You meet this error means `load_from` of trainer is not a Huggingface model path." ) - if self.cur_epoch % self._hf_interval != 0 and self.cur_epoch != self.total_epoch: + if (self.cur_epoch+1) % self._hf_interval != 0 and (self.cur_epoch+1) != self.total_epoch: return save_hf_path = self.exp_dir / f"hf-{self.cur_epoch}" From 43ce367172b3c693bbdf1dda74f52772f52b5cf8 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 24 Sep 2025 11:50:00 +0800 Subject: [PATCH 13/22] fix --- xtuner/v1/train/rl_trainer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 4b5b61840..ebe19e9bc 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -223,7 +223,7 @@ def __init__( temperature=1.0, do_sample=True, max_tokens=dataflow_config.sample_params.max_tokens, - top_k=1, + top_k=0, ) self._eval_step = evaluator_config.evaluate_step else: @@ -282,7 +282,7 @@ def fit(self): scores, eval_data_groups = ray.get( self._evaluator.run.remote(return_samples=True, sample_params=self._evaluator_sample_params) ) - trajectory_save_path = self.exp_dir / "initial_trajectory.jsonl" + trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl" self._save_trajectories(eval_data_groups, trajectory_save_path) self.logger.info(f"Initial rollout evaluate scores {scores} and start training") for rollout_idx in range(1, self._rollout_steps + 1): @@ -311,7 +311,10 @@ def fit(self): ray.get(self._rollout_env_controller.onload_kvcache.remote()) # evaluate if self._enable_evaluate and self._evaluator and rollout_idx % self._eval_step == 0: - scores = ray.get(self._evaluator.run.remote(sample_params=self._evaluator_sample_params)) + scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True, + sample_params=self._evaluator_sample_params)) + trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl" + self._save_trajectories(eval_data_groups, trajectory_save_path) self.logger.info(f"evaluate idx {rollout_idx} scores {scores}") self._cur_epoch += 1 From 45f8a930d42cfb68e9efcc7bff3bf41c73fa9535 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Sun, 28 Sep 2025 19:56:27 +0800 Subject: [PATCH 14/22] update --- ci/scripts/test_dapo_sglang.sh | 38 ++++++++ ci/scripts/test_dapo_trainer.py | 14 ++- xtuner/v1/ray/environment/base_env.py | 4 + xtuner/v1/ray/rollout/sglang.py | 118 ++++++++++++++++--------- xtuner/v1/rl/base/worker.py | 121 +++++++++++++++++++++----- xtuner/v1/train/rl_trainer.py | 1 + 6 files changed, 232 insertions(+), 64 deletions(-) create mode 100644 
ci/scripts/test_dapo_sglang.sh diff --git a/ci/scripts/test_dapo_sglang.sh b/ci/scripts/test_dapo_sglang.sh new file mode 100644 index 000000000..12d8d3962 --- /dev/null +++ b/ci/scripts/test_dapo_sglang.sh @@ -0,0 +1,38 @@ +set -ex + +export XTUNER_USE_SGLANG=1 # 最好训练用 fa3,暂时是 fa2 +export PYTHONPATH=/mnt/shared-storage-user/huanghaian/code/lmdeploy/:$PYTHONPATH +export UVICORN_LOG_LEVEL="CRITICAl" + +# 不支持 +# export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' + +OUTPUT_DIR='work_dirs/dapo_math_7B_newlmdeploy_nogroup_sglang' +if [ ! -d "$OUTPUT_DIR" ]; then + mkdir -p "$OUTPUT_DIR" +fi + +export ROLLOUT_MODEL_PATH='/mnt/shared-storage-user/llmrazor-share/model/Qwen2.5-Math-7B' +export ROLLOUT_DATA_PATH="/mnt/shared-storage-user/huanghaian/code/verl/data/dapo_math/dapo-math-17k.jsonl" +export ROLLOUT_TEST_DATA_PATH="/mnt/shared-storage-user/huanghaian/code/verl/data/dapo_math/aime-2024.jsonl" + +ray stop --force + +# --max-concurrent 如果开大会 oom +python ci/scripts/test_dapo_trainer.py \ + --total-epochs 1 \ + --work-dir "$OUTPUT_DIR" \ + --num-workers 8 \ + --gpus-per-node 8 \ + --rollout-global-batch-size 512 \ + --train-optimizer-steps 16 \ + --max-concurrent 32 \ + --prompt-repeat-k 16 \ + --pack-max-length 32768 \ + --max-prompt-length 2048 \ + --max-response-length 8192 \ + --optimizer-disable-foreach \ + --enable-evaluate \ + --evaluate-step 5 \ + --hf-interval 50 \ + 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" diff --git a/ci/scripts/test_dapo_trainer.py b/ci/scripts/test_dapo_trainer.py index 25dce262c..e5c651641 100644 --- a/ci/scripts/test_dapo_trainer.py +++ b/ci/scripts/test_dapo_trainer.py @@ -81,17 +81,29 @@ def main(args): num_workers=args.num_workers, cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB ) + + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + backend = 'sglang' + launch_server_method = 'multiprocessing' + elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": + backend = 'vllm' + launch_server_method = 'ray' + else: + backend = 'lmdeploy' + launch_server_method = 'ray' rollout_config = RolloutConfig( env="test_env", model_path=args.model_path, model_name=os.path.basename(args.model_path).lower(), tokenizer_path=args.model_path, rollout_cross_node_comm=False, - tensor_parallel_size=2, + tensor_parallel_size=1, # TODO: 暂时写死 expert_parallel_size=1, gpus_per_node=args.gpus_per_node, # gpu: 8, npu: 16 dtype="bfloat16", skip_load_weights=False, + backend=backend, + launch_server_method=launch_server_method, ) dataflow_config = DataFlowConfig( env="test", diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py index 61d968d87..fd29f47a5 100644 --- a/xtuner/v1/ray/environment/base_env.py +++ b/xtuner/v1/ray/environment/base_env.py @@ -56,6 +56,10 @@ def init_rollout_controller(self, placement_group: Any, rollout_cfg: Any): from xtuner.v1.ray.rollout import vLLMWorker rollout_workers_map = AutoAcceleratorWorkers.from_placement_group(vLLMWorker, rollout_cfg, placement_group) + elif rollout_cfg.backend == "sglang": + from xtuner.v1.ray.rollout import SGLangWorker + + rollout_workers_map = AutoAcceleratorWorkers.from_placement_group(SGLangWorker, rollout_cfg, placement_group) else: raise NotImplementedError(f"Rollout backend '{rollout_cfg.backend}' is not supported.") diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/ray/rollout/sglang.py index d3684e39f..b75189dd2 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/ray/rollout/sglang.py @@ -1,11 +1,10 @@ import os from typing import Any, Dict, List, Union - +import 
requests import ray from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs -from starlette.background import BackgroundTask -from starlette.responses import StreamingResponse +from urllib3.exceptions import NewConnectionError from xtuner.v1.ray.config import RolloutConfig @@ -15,43 +14,73 @@ @ray.remote class SGLangWorker(RolloutWorker): def __init__( - self, - config: RolloutConfig, - rank: int, - master_addr: str, - master_port: int, - world_size: int, - accelerator: str = "GPU", + self, + config: RolloutConfig, + rank: int, + master_addr: str, + master_port: int, + world_size: int, + accelerator: str = "GPU", ): super().__init__(config, rank, master_addr, master_port, world_size, accelerator) self.server_func = launch_server self.endpoints["health_generate"] = "health_generate" - self.endpoints["generate"] = "generate" + self.endpoints["generate"] = "v1/chat/completions" + self.api_keys = self.config.api_key + self.model_name = self.config.model_name async def _create_request( - self, - url: str, - prompt: Union[str, List[Dict[str, Any]]], - tools: List, - tool_choice: str, - sample_params: dict, - extra_params: dict, + self, + url: str, + prompt: Union[str, List[Dict[str, Any]]], + tools: List, + tool_choice: str, + sample_params: dict, + extra_params: dict, ): - # default params - sample_params["max_new_tokens"] = sample_params.get("max_tokens", 128) - del sample_params["max_tokens"] - payload = {"stream": True, "sampling_params": sample_params, "text": prompt} - - if extra_params: - payload.update(extra_params) - + sample_params['top_k'] = -1 # TODO: 暂时写死 + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_keys}", # 如果需要鉴权 + } + payload = { + "model": self.model_name, + "messages": prompt, + "stream": True, + } + payload.update(sample_params) + payload.update(extra_params) req = self.client.build_request( "POST", url, + headers=headers, json=payload, ) r = await self.client.send(req, stream=True) - return StreamingResponse(r.aiter_text(), background=BackgroundTask(r.aclose)) + return r + + def _make_request(self, endpoint: str, payload=None): + # TODO: 支持 tp + url = f"{self.server_url}/{endpoint}" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + return response.json() + + def flush_cache(self): + """Flush the cache of the server.""" + # TODO: 支持 tp + # flush cache will not return status_code 200 when there are pending requests + while True: + try: + response = requests.get(f"{self.server_url}/flush_cache") + if response.status_code == 200: + break + except NewConnectionError as e: + raise e + except Exception as e: + print(f"Error flushing cache: {e}") + continue def get_logprobs(self, input_ids, sampling_params): return self._make_request( @@ -59,41 +88,50 @@ def get_logprobs(self, input_ids, sampling_params): {"input_ids": input_ids, "sampling_params": sampling_params, "stream": False, "return_logprob": True}, ) - def sleep(self, level=1): + def offload(self): + """Offloads the model weights and KV cache.""" + self.flush_cache() return self._make_request("release_memory_occupation") - def wake_up(self): - return self._make_request("resume_memory_occupation") + def onload_weights(self): + """Onloads the model weights by waking up the model.""" + return self._make_request("resume_memory_occupation", {"tags": ["weights"]}) + + def onload_kvcache(self): + return self._make_request("resume_memory_occupation", {"tags": ["kv_cache"]}) def pause_generation(self): - 
return self._make_request("pause_generation") + pass + # return self._make_request("pause_generation") def continue_generation(self): - return self._make_request("continue_generation") - - def shutdown(self): pass + # return self._make_request("continue_generation") - def update_weights(self, ipc_handles): + def shutdown(self): pass def reset_prefix_cache(self): pass - def _transform_rollout_config_to_server_configs(self, infer_config): + def _transform_rollout_config_to_server_configs(self): # remove the CUDA_VISIBLE_DEVICES set by ray and use base_gpu_id os.environ.pop("CUDA_VISIBLE_DEVICES", None) - sglang_server_args = ServerArgs(model_path=infer_config.model_path) + sglang_server_args = ServerArgs(model_path=self.config.model_path) sglang_server_args.host = self.host sglang_server_args.port = self.server_port sglang_server_args.nccl_port = self.nccl_port sglang_server_args.dist_init_addr = self.dist_init_addr - sglang_server_args.base_gpu_id = self.rank % infer_config.gpus_per_node + sglang_server_args.base_gpu_id = self.rank % self.config.gpus_per_node sglang_server_args.gpu_id_step = 1 - sglang_server_args.nnodes = max(1, infer_config.tensor_parallel_size // infer_config.gpus_per_node) + sglang_server_args.nnodes = max(1, self.config.tensor_parallel_size // self.config.gpus_per_node) + sglang_server_args.skip_server_warmup = True + sglang_server_args.tp_size = self.config.tensor_parallel_size + sglang_server_args.mem_fraction_static = 0.7 # 关键 + sglang_server_args.enable_memory_saver = True # 关键,否则显存释放不了 if sglang_server_args.nnodes > 1: - sglang_server_args.node_rank = self.rank // infer_config.gpus_per_node + sglang_server_args.node_rank = self.rank // self.config.gpus_per_node else: sglang_server_args.node_rank = 0 return sglang_server_args diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index c9c855a20..630e15421 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -34,6 +34,18 @@ logger = get_logger() +try: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + from sglang.srt.utils import MultiprocessingSerializer + + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + use_flattened_tensor_bucket = True +except: + use_flattened_tensor_bucket = False + +print(f'[SGLang] use_flattened_tensor_bucket: {use_flattened_tensor_bucket}') + def serialize_state_dict(state_dict: dict) -> str: """Serialize state dict to str. 
@@ -391,9 +403,13 @@ def update_rollout_info( self.rollout_cfg_info["tp"] = tp self.rollout_cfg_info["ep"] = ep self.rollout_cfg_info["api_key"] = rollout_config.api_key - self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get( - "lmdeploy_backend", "pytorch" - ) + + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + self.rollout_cfg_info["backend"] = 'sglang' + else: + self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get( + "lmdeploy_backend", "pytorch" + ) def update_weights(self): """Update the model weights.""" @@ -435,8 +451,14 @@ def get_params(tensor_list, name_list, save_dtype): name_list.append(name) tensor_list.append((local_tensor, load_spec)) fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) + if self.rollout_cfg_info["backend"] == "pytorch": + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + self.request_update_params(state_dict) + else: + state_dict = [] + for name, tensor in zip(name_list, fsdp_unshard_tensor_list): + state_dict.append((name, tensor)) + self.request_update_params(state_dict) tensor_list = [] name_list = [] @@ -453,10 +475,17 @@ def get_params(tensor_list, name_list, save_dtype): tensor_list = [(local_tensor, load_spec)] name_list = [name] fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - self.request_update_params({}, finished=True) + if self.rollout_cfg_info["backend"] == "pytorch": + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + self.request_update_params(state_dict) + else: + state_dict = [] + for name, tensor in zip(name_list, fsdp_unshard_tensor_list): + state_dict.append((name, tensor)) + self.request_update_params(state_dict) + + if self.rollout_cfg_info["backend"] == "pytorch": + self.request_update_params({}, finished=True) dist.barrier() logger.info(f"update weights time: {time.time() - time1}") @@ -641,31 +670,77 @@ def get_params(tensor_list, name_list, save_dtype): # return def request_update_params(self, state_dict, finished=False): + """Send a request to update the parameters on the rollout workers. + + This method serializes the state dictionary and sends it to the + appropriate rollout worker via an HTTP request. + + Args: + state_dict (dict | list): The state dictionary containing the model + parameters to update. + finished (bool): A flag indicating whether this is the final + batch of updates. Defaults to False. 
+ """ cpu_mesh = self.rollout_device_mesh["engine_parallel"] cpu_group = cpu_mesh.get_group() head_rank = cpu_mesh.mesh[0].item() - if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: - serialized_data = [None] * self.rollout_cfg_info["tp"] - tmp_serialized_data = serialize_state_dict(state_dict) - dist.gather_object( - tmp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) + if self.rollout_cfg_info["backend"] == "pytorch": + # TODO(chenchiyu): remove lmdeploy related code + from lmdeploy.utils import serialize_state_dict + + if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: + serialized_data = [None] * self.rollout_cfg_info["tp"] + dist.gather_object( + serialize_state_dict(state_dict), + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + elif self.rollout_cfg_info["backend"] == "pytorch": + serialized_data = serialize_state_dict(state_dict) + else: + # for turbomind backend, only head_rank should serialize data + serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None else: - serialized_data = serialize_state_dict(state_dict) + # sglang + monkey_patch_torch_reductions() + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) + metadata = flattened_tensor_bucket.get_metadata() + + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + else: + serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) + + # TODO: 支持 tp + serialized_data = [serialized_data] if dist.get_rank() == head_rank: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", } - data = dict(serialized_named_tensors=serialized_data, finished=finished) - response = requests.post( - f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data - ) + if self.rollout_cfg_info["backend"] == "sglang": + payload = { + "serialized_named_tensors": serialized_data, + "flush_cache": False, + } + if use_flattened_tensor_bucket: + payload["load_format"] = "flattened_bucket" + + url = f"{self.rollout_url}/update_weights_from_tensor" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + else: + data = dict(serialized_named_tensors=serialized_data, finished=finished) + response = requests.post( + f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data + ) assert response.status_code == 200, f"response.status_code = {response.status_code}" if finished: diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index ebe19e9bc..6cd1f6157 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -292,6 +292,7 @@ def fit(self): trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl" self._save_trajectories(data_groups, trajectory_save_path) self.logger.info(f"rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") + time.sleep(3) ray.get(self._train_controller.onload.remote(target="all")) self.logger.info("Training controller loaded") data_batches, data_info = self._prepare_train_data(data_groups, self._train_worker_cfg.pack_max_length) From 
e63562d860fd339e03555f3300633ed1ea6577f8 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Sun, 28 Sep 2025 19:58:11 +0800 Subject: [PATCH 15/22] fix print --- xtuner/v1/rl/grpo/loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/grpo/loss.py index 34dc9602a..e8d9ced41 100644 --- a/xtuner/v1/rl/grpo/loss.py +++ b/xtuner/v1/rl/grpo/loss.py @@ -144,7 +144,10 @@ def loss_fn( policy_loss_weight, self.loss_cfg.policy_loss_cfg, ) + + # 只看响应部分 ratio = (logprobs - old_logprobs.detach()).exp() + ratio = ratio * policy_loss_weight.float() if self.loss_cfg.use_kl_loss: ref_logprobs = loss_kwargs.ref_logprobs From dfafbd5c029397a464e039aa9408a74c60b031a3 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Sun, 28 Sep 2025 21:12:47 +0800 Subject: [PATCH 16/22] fix print --- .dev_scripts/mypy_entrypoint.sh | 2 +- ci/scripts/test_dapo.sh | 2 +- ci/scripts/test_dapo_trainer.py | 6 ++++-- xtuner/v1/rl/grpo/loss.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.dev_scripts/mypy_entrypoint.sh b/.dev_scripts/mypy_entrypoint.sh index 9904c17d2..50c174f8a 100755 --- a/.dev_scripts/mypy_entrypoint.sh +++ b/.dev_scripts/mypy_entrypoint.sh @@ -1,4 +1,4 @@ -#! /usr/bin/bash +#! /bin/bash set -e export PYTHONPATH=$(dirname $0)/.. diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh index 6e0c5ac54..23d9870a2 100644 --- a/ci/scripts/test_dapo.sh +++ b/ci/scripts/test_dapo.sh @@ -24,7 +24,7 @@ python ci/scripts/test_dapo_trainer.py \ --gpus-per-node 8 \ --rollout-global-batch-size 512 \ --train-optimizer-steps 16 \ - --max-concurrent 64 \ + --max-concurrent 32 \ --prompt-repeat-k 16 \ --pack-max-length 32768 \ --max-prompt-length 2048 \ diff --git a/ci/scripts/test_dapo_trainer.py b/ci/scripts/test_dapo_trainer.py index e5c651641..e2d046c44 100644 --- a/ci/scripts/test_dapo_trainer.py +++ b/ci/scripts/test_dapo_trainer.py @@ -39,8 +39,10 @@ MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -os.environ['XTUNER_USE_FA3'] = "1" +os.environ['XTUNER_USE_FA3'] = "1" +if os.environ['XTUNER_USE_FA3'] == "1": + from flash_attn_interface import flash_attn_3_cuda def parse_args(): parser = argparse.ArgumentParser(description="VLLM Rollout Test Script") @@ -97,7 +99,7 @@ def main(args): model_name=os.path.basename(args.model_path).lower(), tokenizer_path=args.model_path, rollout_cross_node_comm=False, - tensor_parallel_size=1, # TODO: 暂时写死 + tensor_parallel_size=1, # TODO: sglang 暂时写死 expert_parallel_size=1, gpus_per_node=args.gpus_per_node, # gpu: 8, npu: 16 dtype="bfloat16", diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/grpo/loss.py index e8d9ced41..f8de6e867 100644 --- a/xtuner/v1/rl/grpo/loss.py +++ b/xtuner/v1/rl/grpo/loss.py @@ -147,7 +147,7 @@ def loss_fn( # 只看响应部分 ratio = (logprobs - old_logprobs.detach()).exp() - ratio = ratio * policy_loss_weight.float() + ratio = ratio * (shifted_labels != self.loss_cfg.ignore_idx).float() if self.loss_cfg.use_kl_loss: ref_logprobs = loss_kwargs.ref_logprobs From 284ff29929da9613a7b40800f67c7bb2efa744b8 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Mon, 29 Sep 2025 11:41:04 +0800 Subject: [PATCH 17/22] support generate rollout of sglang --- ci/scripts/test_dapo_sglang.sh | 1 + xtuner/v1/ray/rollout/sglang.py | 36 ++++++++++++++++------ xtuner/v1/ray/rollout/worker.py | 54 ++++++++++++++++++++++++--------- 3 files changed, 67 insertions(+), 24 deletions(-) diff --git a/ci/scripts/test_dapo_sglang.sh 
b/ci/scripts/test_dapo_sglang.sh index 12d8d3962..11217fa80 100644 --- a/ci/scripts/test_dapo_sglang.sh +++ b/ci/scripts/test_dapo_sglang.sh @@ -3,6 +3,7 @@ set -ex export XTUNER_USE_SGLANG=1 # 最好训练用 fa3,暂时是 fa2 export PYTHONPATH=/mnt/shared-storage-user/huanghaian/code/lmdeploy/:$PYTHONPATH export UVICORN_LOG_LEVEL="CRITICAl" +#export ID_INPUT_OUTPUT=1 # 不支持 # export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/ray/rollout/sglang.py index b75189dd2..3e621be1f 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/ray/rollout/sglang.py @@ -7,8 +7,9 @@ from urllib3.exceptions import NewConnectionError from xtuner.v1.ray.config import RolloutConfig - +from transformers import AutoTokenizer from .worker import RolloutWorker +import os @ray.remote @@ -25,7 +26,11 @@ def __init__( super().__init__(config, rank, master_addr, master_port, world_size, accelerator) self.server_func = launch_server self.endpoints["health_generate"] = "health_generate" - self.endpoints["generate"] = "v1/chat/completions" + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + self.endpoints["generate"] = "generate" + self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) + else: + self.endpoints["generate"] = "v1/chat/completions" self.api_keys = self.config.api_key self.model_name = self.config.model_name @@ -44,13 +49,26 @@ async def _create_request( "Content-Type": "application/json", "Authorization": f"Bearer {self.api_keys}", # 如果需要鉴权 } - payload = { - "model": self.model_name, - "messages": prompt, - "stream": True, - } - payload.update(sample_params) - payload.update(extra_params) + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + payload = {"model": self.model_name, "stream": True, "return_logprob": True} + text_prompt = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + payload["input_ids"] = prompt_token_ids + + new_sample_params = {"max_new_tokens": sample_params['max_tokens'], + "temperature": sample_params['temperature'], + "top_p": sample_params['top_p'], + "top_k": sample_params['top_k'] + } + payload['sampling_params'] = new_sample_params + else: + payload = { + "model": self.model_name, + "messages": prompt, + "stream": True, + } + payload.update(sample_params) + payload.update(extra_params) req = self.client.build_request( "POST", url, diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 8560c3a25..a20654d00 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -5,7 +5,7 @@ import uuid from abc import abstractmethod from typing import Any, Callable, Dict, List, Optional, Union - +import os import httpx import ray import requests # type: ignore[import-untyped] @@ -80,7 +80,8 @@ class RolloutRequest(BaseModel): class RolloutResponse(BaseModel): response: str = "" - logprobs: float = 0.0 + logprobs: list = [] + response_ids: list = [] finish_reason: str = "" reasoning_content: str = "" usage: dict = Field(default_factory=dict) @@ -131,6 +132,10 @@ def __init__( self.server_process: Optional[multiprocessing.Process] = None self.logger = get_logger() + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) + def init_dist_port(self): """Initialize distributed communication ports. 
@@ -367,7 +372,7 @@ async def rollout_task( if not chunk.startswith("data:"): continue try: - chunk_data_str = chunk[len("data:") :].strip() + chunk_data_str = chunk[len("data:"):].strip() if self.paused or chunk_data_str == "[DONE]": finish_reason = "paused" if self.paused else finish_reason break @@ -375,14 +380,25 @@ async def rollout_task( continue chunk_data = json.loads(chunk_data_str) - delta_content = chunk_data["choices"][0]["delta"].get("content") - last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory - finish_reason = chunk_data["choices"][0].get("finish_reason") - - # todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids. - if finish_reason == "stop": - assert len(sample_params["stops"]) == 1 - last_trajectory += sample_params["stops"][0] + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + # TODO: 不太懂,好像是一个假的流式,每次返回的都会包括之前的? + last_trajectory = chunk_data['text'] + if "output_token_logprobs" in chunk_data["meta_info"]: + new_response_tokens = [item[1] for item in chunk_data["meta_info"]["output_token_logprobs"]] + new_response_log_probs = [item[0] for item in chunk_data["meta_info"]["output_token_logprobs"]] + finish_reason = chunk_data["meta_info"].get("finish_reason") + if finish_reason is not None: + assert isinstance(finish_reason, dict) + finish_reason = finish_reason['type'] + else: + delta_content = chunk_data["choices"][0]["delta"].get("content") + last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory + finish_reason = chunk_data["choices"][0].get("finish_reason") + + # todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids. + if finish_reason == "stop": + assert len(sample_params["stops"]) == 1 + last_trajectory += sample_params["stops"][0] except json.JSONDecodeError as e: self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}") @@ -395,10 +411,18 @@ async def rollout_task( f"Unexpected finish_reason: {finish_reason}" ) - rollout_response = RolloutResponse( - response=last_trajectory, - finish_reason=finish_reason, - ) + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + rollout_response = RolloutResponse( + logprobs=new_response_log_probs, + response_ids=new_response_tokens, + response=last_trajectory, + finish_reason=finish_reason, + ) + else: + rollout_response = RolloutResponse( + response=last_trajectory, + finish_reason=finish_reason, + ) return rollout_response except httpx.RequestError as e: From 14a2882c616b08b0c9b5d30929f5d16729a65b77 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Mon, 29 Sep 2025 12:32:42 +0800 Subject: [PATCH 18/22] update --- xtuner/v1/datasets/data_item.py | 1 + xtuner/v1/ray/dataflow/replay_buffer.py | 29 ++++++++++++-------- xtuner/v1/ray/environment/single_turn_env.py | 2 ++ xtuner/v1/train/rl_trainer.py | 18 ++++++++---- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/xtuner/v1/datasets/data_item.py b/xtuner/v1/datasets/data_item.py index 611e565ac..9dd30531a 100644 --- a/xtuner/v1/datasets/data_item.py +++ b/xtuner/v1/datasets/data_item.py @@ -43,6 +43,7 @@ class RLTextDataItem(CacheItem, total=False): reward: float | None num_return_tokens: int | None response_ids: list[int] | None + logprobs: list | None response_str: str | None state: str retry_times: int diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 13bf253ed..7fc00bcc0 100644 --- 
a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -179,9 +179,9 @@ def sample(self, env: str, enable_partial_rollout: int, prompt_repeat_k: int) -> List[RLTextDataItem]: A list of sampled data items. """ if ( - enable_partial_rollout - and "unfinished" in self.storage._rollout_states - and len(self.storage._rollout_states["unfinished"]) > 0 + enable_partial_rollout + and "unfinished" in self.storage._rollout_states + and len(self.storage._rollout_states["unfinished"]) > 0 ): return self.sample_from_unfinished_buffer() else: @@ -230,10 +230,10 @@ def replaymeta2dataitem(self, replay_meta: ReplayMeta, messages=None, input_ids= ) input_ids = input_ids or [] num_tokens = len(input_ids) - response_str = ( + response_dict = ( ray.get(replay_meta.observation_refs[0]) if replay_meta.observation_refs and len(replay_meta.observation_refs) > 0 - else "" + else {} ) return RLTextDataItem( env=replay_meta.env, @@ -242,7 +242,9 @@ def replaymeta2dataitem(self, replay_meta: ReplayMeta, messages=None, input_ids= messages=messages, input_ids=input_ids, num_tokens=num_tokens, - response_str=response_str, + response_str=response_dict['response_str'], + logprobs=response_dict['logprobs'], + response_ids=response_dict['response_ids'], reward_model={"ground_truth": replay_meta.ground_truth}, reward=replay_meta.rewards[-1] if replay_meta.rewards and len(replay_meta.rewards) > 0 else None, state=replay_meta.state, @@ -263,7 +265,10 @@ def dataitem2replaymeta(self, data_item: RLTextDataItem) -> ReplayMeta: action_id=data_item["prompt_id"], action_refs=[ray.put(data_item["messages"])] if "messages" in data_item else [], observation_ids=[uuid.uuid4().int], - observation_refs=[ray.put(data_item["response_str"])] if "response_str" in data_item else [], + observation_refs=[ray.put({'response_str': data_item["response_str"], + 'logprobs': data_item["logprobs"], + 'response_ids': data_item["response_ids"] + })] if "response_str" in data_item else [], observation_versions=[1], state=data_item["state"] if "state" in data_item else "", ground_truth=data_item["reward_model"]["ground_truth"], @@ -372,7 +377,7 @@ def dump(self, file_path: str): actions_list.append(ray.get(ref)) replay_meta.action_refs = actions_list if replay_meta.observation_refs and all( - isinstance(ref, ray.ObjectRef) for ref in replay_meta.observation_refs + isinstance(ref, ray.ObjectRef) for ref in replay_meta.observation_refs ): observations_list = [] for ref in replay_meta.observation_refs: @@ -457,8 +462,8 @@ class ReplayBuffer: learning.""" def __init__( - self, - config: ReplayBufferConfig, + self, + config: ReplayBufferConfig, ): """Initializes the ReplayBuffer actor. @@ -517,8 +522,8 @@ def sample(self, env, enable_partial_rollout: int, prompt_repeat_k: int): return self.sampler.sample(env, enable_partial_rollout, prompt_repeat_k) def get_samples( - self, - global_batch_size: int, + self, + global_batch_size: int, ): """Gets a batch of finished samples from the storage. 
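For orientation, the per-sample observation the replay buffer stores is now a small dict rather than a bare response string. A minimal round-trip sketch of what dataitem2replaymeta puts into the Ray object store and replaymeta2dataitem reads back (values are made up):

import ray

ray.init(ignore_reinit_error=True)

# what dataitem2replaymeta ray.put()s per sample
obs_ref = ray.put(
    {
        "response_str": "The answer is 42.",
        "logprobs": [-0.11, -0.52, -0.03],
        "response_ids": [785, 4226, 374, 220, 19, 17, 13],
    }
)

# what replaymeta2dataitem ray.get()s when rebuilding an RLTextDataItem
obs = ray.get(obs_ref)
assert set(obs) == {"response_str", "logprobs", "response_ids"}
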
diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py index d5737abbf..c8ecdab7f 100644 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ b/xtuner/v1/ray/environment/single_turn_env.py @@ -53,6 +53,8 @@ async def generate(self, group_samples: List[RLTextDataItem], sample_params: Non for i in range(len(group_samples)): group_samples[i]["response_str"] = response[i].response group_samples[i]["state"] = response[i].finish_reason + group_samples[i]["logprobs"] = response[i].logprobs + group_samples[i]["response_ids"] = response[i].response_ids return group_samples diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 6cd1f6157..f27d124da 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -329,10 +329,13 @@ def _prepare_train_data(self, data_groups, pack_max_length): data_batches = [] for group in data_groups: - prompt = self.tokenizer.apply_chat_template( - group[0]["messages"], add_generation_prompt=True, tokenize=False - ) - prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].flatten().tolist() + text_prompt = self.tokenizer.apply_chat_template(group[0]["messages"], tokenize=False, add_generation_prompt=True) + prompt_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"].flatten().tolist() + + # prompt = self.tokenizer.apply_chat_template( + # group[0]["messages"], add_generation_prompt=True, tokenize=False + # ) + # prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].flatten().tolist() rewards = [data["reward"] for data in group] rewards_list.extend(rewards) rewards = torch.tensor(rewards, dtype=torch.float32) @@ -341,7 +344,12 @@ def _prepare_train_data(self, data_groups, pack_max_length): prompt_repeat_k = len(group) for i in range(prompt_repeat_k): item = group[i]["response_str"] - response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() + if 'response_ids' in group[i] and group[i]['response_ids'] is not None: + response_ids = group[i]['response_ids'] + if isinstance(response_ids, torch.Tensor): + response_ids = response_ids.flatten().tolist() + else: + response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() input_ids = prompt_ids + response_ids prompt_len_list.append(len(prompt_ids)) From 089a2dbce9bfc8b515842c6fd4904cd016a2c3d8 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Mon, 29 Sep 2025 14:32:13 +0800 Subject: [PATCH 19/22] add logprob --- ci/scripts/test_dapo_sglang.sh | 3 ++- xtuner/v1/ray/rollout/sglang.py | 5 ++++- xtuner/v1/rl/base/controller.py | 24 ++++++++++++++++++++++++ xtuner/v1/rl/base/worker.py | 3 +++ xtuner/v1/train/rl_trainer.py | 15 +++++++++++++-- 5 files changed, 46 insertions(+), 4 deletions(-) diff --git a/ci/scripts/test_dapo_sglang.sh b/ci/scripts/test_dapo_sglang.sh index 11217fa80..0ee5a1c66 100644 --- a/ci/scripts/test_dapo_sglang.sh +++ b/ci/scripts/test_dapo_sglang.sh @@ -3,7 +3,7 @@ set -ex export XTUNER_USE_SGLANG=1 # 最好训练用 fa3,暂时是 fa2 export PYTHONPATH=/mnt/shared-storage-user/huanghaian/code/lmdeploy/:$PYTHONPATH export UVICORN_LOG_LEVEL="CRITICAl" -#export ID_INPUT_OUTPUT=1 +export ID_INPUT_OUTPUT=1 # 不支持 # export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' @@ -36,4 +36,5 @@ python ci/scripts/test_dapo_trainer.py \ --enable-evaluate \ --evaluate-step 5 \ --hf-interval 50 \ + --evaluate-ratio 1 \ 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" diff --git a/xtuner/v1/ray/rollout/sglang.py 
b/xtuner/v1/ray/rollout/sglang.py index 3e621be1f..1a55e9a08 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/ray/rollout/sglang.py @@ -58,7 +58,10 @@ async def _create_request( new_sample_params = {"max_new_tokens": sample_params['max_tokens'], "temperature": sample_params['temperature'], "top_p": sample_params['top_p'], - "top_k": sample_params['top_k'] + "top_k": sample_params['top_k'], + "no_stop_trim": True, + "skip_special_tokens": False, + "spaces_between_special_tokens":False, } payload['sampling_params'] = new_sample_params else: diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index 5c0f3e11b..b409fc3c8 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -13,6 +13,7 @@ class ColateItem(TypedDict): seq_ctx: SequenceContext shifted_labels: torch.Tensor advantage: float + rollout_logprobs: torch.Tensor | None @ray.remote @@ -71,6 +72,11 @@ def _packing(self, data_batches, pack_max_length): seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices] label_list = [data_batches[i]["shifted_labels"] for i in indices] advantage_list = [data_batches[i]["advantage"] for i in indices] + + rollout_logprobs_list = None + if 'rollout_logprobs' in data_batches[0] and data_batches[0]['rollout_logprobs'] is not None: + rollout_logprobs_list = [data_batches[i]['rollout_logprobs'] for i in indices] + if pad_len > 0: # Reduce the attn calculation time by using multiple short sequence packs pad_tokens = tuple( @@ -95,6 +101,12 @@ def _packing(self, data_batches, pack_max_length): [-100] * math.ceil(pad_len / 1024) ) # can be any number, pad tokens are excluded from the calculation of the loss function. + if rollout_logprobs_list is not None: + pad_rollout_logprobs = torch.zeros( + 1, pad_len, dtype=data_batches[0]['rollout_logprobs'].dtype, device=data_batches[0]['shifted_labels'].device + ) + rollout_logprobs_list.append(pad_rollout_logprobs) + seq_ctx = SequenceContext.pack(seq_ctx_list) shifted_labels = torch.cat(label_list, dim=1) # (1, max_len) advantages = torch.tensor(advantage_list).float().unsqueeze(0) # (1, num_samples) @@ -102,11 +114,16 @@ def _packing(self, data_batches, pack_max_length): num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] advantages = torch.repeat_interleave(advantages, num_tokens, dim=1) # (1, max_len) + rollout_logprobs = None + if rollout_logprobs_list is not None: + rollout_logprobs = torch.cat(rollout_logprobs_list, dim=1) # (1, max_len) + packed_data_batches.append( { "seq_ctx": seq_ctx, "shifted_labels": shifted_labels, "advantages": advantages, + "rollout_logprobs": rollout_logprobs, } ) return packed_data_batches @@ -152,10 +169,17 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: dtype=packed_data_batches[0]["advantages"].dtype, device="cpu", ) + + pad_rollout_logprobs = None + if 'rollout_logprobs' in packed_data_batches[0] and packed_data_batches[0]['rollout_logprobs'] is not None: + pad_rollout_logprobs = torch.zeros( + 1, pack_max_length, dtype=packed_data_batches[0]['rollout_logprobs'].dtype, device="cpu" + ) pad_data = { "seq_ctx": pad_seq_ctx, "shifted_labels": pad_shifted_labels, "advantages": pad_advantages, + "rollout_logprobs": pad_rollout_logprobs, } pad_data_samples = [pad_data for _ in range(pad_num)] packed_data_batches = packed_data_batches + pad_data_samples diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 630e15421..529a34863 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ 
-131,6 +131,7 @@ class WorkerInputItem(TypedDict): seq_ctx: SequenceContext shifted_labels: torch.LongTensor advantages: torch.Tensor + rollout_logprobs: torch.Tensor | None class TrainingWorker(SingleAcceleratorWorker): @@ -261,6 +262,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): seq_ctx_list: list[SequenceContext] = [] loss_ctx_input_list: list[RLLossContextInputItem] = [] + rollout_logprobs_list: list[torch.Tensor | None] = [] for data in data_batches: seq_ctx = data["seq_ctx"].to(DEVICE) loss_ctx_input = RLLossContextInputItem( @@ -272,6 +274,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh) seq_ctx_list.append(seq_ctx) loss_ctx_input_list.append(loss_ctx_input) + rollout_logprobs_list.append(data["rollout_logprobs"]) del data_batches diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index f27d124da..517d61fb7 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -326,11 +326,13 @@ def _prepare_train_data(self, data_groups, pack_max_length): advantages_list = [] prompt_len_list = [] response_len_list = [] + rollout_logprobs_list = [] data_batches = [] for group in data_groups: - text_prompt = self.tokenizer.apply_chat_template(group[0]["messages"], tokenize=False, add_generation_prompt=True) - prompt_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"].flatten().tolist() + text_prompt = self.tokenizer.apply_chat_template(group[0]["messages"], tokenize=False, + add_generation_prompt=True) + prompt_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] # prompt = self.tokenizer.apply_chat_template( # group[0]["messages"], add_generation_prompt=True, tokenize=False @@ -348,6 +350,7 @@ def _prepare_train_data(self, data_groups, pack_max_length): response_ids = group[i]['response_ids'] if isinstance(response_ids, torch.Tensor): response_ids = response_ids.flatten().tolist() + rollout_logprobs_list.extend(group[i]["logprobs"]) else: response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() input_ids = prompt_ids + response_ids @@ -362,11 +365,19 @@ def _prepare_train_data(self, data_groups, pack_max_length): shifted_labels = shifted_labels[:pack_max_length] input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) + + if len(rollout_logprobs_list) > 0: + rollout_logprobs = torch.tensor(rollout_logprobs_list, dtype=torch.float32).unsqueeze(0) + assert rollout_logprobs.size() == shifted_labels.size() + else: + rollout_logprobs = None + data_batches.append( dict( seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"), shifted_labels=shifted_labels, advantage=advantages[i].item(), + rollout_logprobs=rollout_logprobs, ) ) random.shuffle(data_batches) From d8168fada917b5191770ee1f2216c4e3bbe994b5 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Mon, 29 Sep 2025 16:32:32 +0800 Subject: [PATCH 20/22] fix logprob print --- xtuner/v1/rl/base/worker.py | 46 +++++++++++++++++++++++++++++++++-- xtuner/v1/train/rl_trainer.py | 16 ++++++++---- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 529a34863..cdeabf3ab 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -274,7 +274,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): loss_ctx_input = 
loss_ctx_input.sp_split(self.sp_mesh) seq_ctx_list.append(seq_ctx) loss_ctx_input_list.append(loss_ctx_input) - rollout_logprobs_list.append(data["rollout_logprobs"]) + if "rollout_logprobs" in data and data["rollout_logprobs"] is not None: + rollout_logprobs_list.append(data["rollout_logprobs"].to(DEVICE)) del data_batches @@ -290,10 +291,51 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): # old logprobs are inplaced updated in compute_actor_logprobs loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list) sum_entropy: torch.Tensor | None = None - for loss_ctx_input in loss_ctx_input_list: + if len(rollout_logprobs_list) > 0: + assert len(rollout_logprobs_list) == len( + loss_ctx_input_list), f'rollout_logprobs_list {len(rollout_logprobs_list)} vs loss_ctx_input_list {len(loss_ctx_input_list)}' + + all_diffs = [] + for i, loss_ctx_input in enumerate(loss_ctx_input_list): mask = loss_ctx_input.shifted_labels != -100 entropy = -(cast(torch.Tensor, loss_ctx_input.old_logprobs) * mask).sum() sum_entropy = entropy if sum_entropy is None else sum_entropy + entropy + + if len(rollout_logprobs_list) > 0: + rollout_logprobs = rollout_logprobs_list[i][mask] + old_logprobs = loss_ctx_input.old_logprobs[mask] + + # compute the difference between rollout and recomputed logprobs + assert len( + rollout_logprobs.size()) == 1, f"len(rollout_logprobs.size()): {len(rollout_logprobs.size())}" + assert rollout_logprobs.shape == old_logprobs.shape, f'rollout_logprobs {rollout_logprobs.shape} vs old_logprobs {old_logprobs.shape}' + if rollout_logprobs.numel() == 0: # empty for pad samples + min_diff = torch.tensor(0) + max_diff = min_diff + std_diff = min_diff + mean_diff = min_diff + else: + min_diff = torch.min(rollout_logprobs - old_logprobs) + max_diff = torch.max(rollout_logprobs - old_logprobs) + mean_diff = torch.mean(rollout_logprobs - old_logprobs) + if rollout_logprobs.numel() == 1: + std_diff = torch.tensor(0) + else: + std_diff = torch.std(rollout_logprobs - old_logprobs) + all_diffs.append((min_diff, max_diff, mean_diff, std_diff)) + + if len(rollout_logprobs_list) > 0: + all_diffs_tensor = torch.stack([torch.tensor(d).to(DEVICE) for d in all_diffs]) # n, 4 + min_diff = torch.min(all_diffs_tensor[:, 0]).item() + max_diff = torch.max(all_diffs_tensor[:, 1]).item() + mean_diff = torch.mean(all_diffs_tensor[:, 2]).item() + if all_diffs_tensor[:, 3].numel() <= 1: + std_diff = 0 + else: + std_diff = torch.std(all_diffs_tensor[:, 3]).item() + logger.info( + f"Rollout {rollout_idx}: logprobs diff min {min_diff:.4f}, max {max_diff:.4f}, mean {mean_diff:.4f}, std {std_diff:.4f}") + sum_entropy = cast(torch.Tensor, sum_entropy) dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM) avg_gen_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else 0 diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 517d61fb7..ec2ee65f1 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -326,7 +326,6 @@ def _prepare_train_data(self, data_groups, pack_max_length): advantages_list = [] prompt_len_list = [] response_len_list = [] - rollout_logprobs_list = [] data_batches = [] for group in data_groups: @@ -346,11 +345,15 @@ def _prepare_train_data(self, data_groups, pack_max_length): prompt_repeat_k = len(group) for i in range(prompt_repeat_k): item = group[i]["response_str"] + logprobs = None if 'response_ids' in group[i] and group[i]['response_ids'] is not None: response_ids = group[i]['response_ids'] if isinstance(response_ids, torch.Tensor): response_ids =
response_ids.flatten().tolist() - rollout_logprobs_list.extend(group[i]["logprobs"]) + logprobs = group[i]["logprobs"] + assert len(logprobs) == len(response_ids), f'{len(logprobs)} vs {len(response_ids)}' + # only the response part has logprobs, so pad the prompt positions in front + logprobs = [0] * (len(prompt_ids) - 1) + logprobs + [0] else: response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() input_ids = prompt_ids + response_ids @@ -361,14 +364,17 @@ def _prepare_train_data(self, data_groups, pack_max_length): shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] if len(input_ids) > pack_max_length: + print('....... this should not happen ......') input_ids = input_ids[:pack_max_length] shifted_labels = shifted_labels[:pack_max_length] + if logprobs is not None: + logprobs = logprobs[:pack_max_length] input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - if len(rollout_logprobs_list) > 0: - rollout_logprobs = torch.tensor(rollout_logprobs_list, dtype=torch.float32).unsqueeze(0) - assert rollout_logprobs.size() == shifted_labels.size() + if logprobs is not None: + rollout_logprobs = torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0) + assert rollout_logprobs.size() == shifted_labels.size(), f'{rollout_logprobs.size()} vs {shifted_labels.size()}' else: rollout_logprobs = None From 9f8a8637dc310255342f4de9ff887d61c173afd4 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Mon, 29 Sep 2025 19:25:48 +0800 Subject: [PATCH 21/22] update dump --- xtuner/v1/ray/rollout/sglang.py | 17 ++++-- xtuner/v1/ray/rollout/worker.py | 103 +++++++++++++++++--------------- xtuner/v1/train/rl_trainer.py | 45 ++++++++++++++ 3 files changed, 114 insertions(+), 51 deletions(-) diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/ray/rollout/sglang.py index 1a55e9a08..5d3d585f0 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/ray/rollout/sglang.py @@ -50,7 +50,8 @@ async def _create_request( "Authorization": f"Bearer {self.api_keys}", # in case auth is required } if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': - payload = {"model": self.model_name, "stream": True, "return_logprob": True} + stream = False + payload = {"model": self.model_name, "stream": stream, "return_logprob": True} text_prompt = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] payload["input_ids"] = prompt_token_ids @@ -61,14 +62,15 @@ async def _create_request( "top_k": sample_params['top_k'], "no_stop_trim": True, "skip_special_tokens": False, - "spaces_between_special_tokens":False, + "spaces_between_special_tokens": False, } payload['sampling_params'] = new_sample_params else: + stream = True payload = { "model": self.model_name, "messages": prompt, - "stream": True, + "stream": stream, } payload.update(sample_params) payload.update(extra_params) @@ -78,7 +80,14 @@ async def _create_request( headers=headers, json=payload, ) - r = await self.client.send(req, stream=True) + r = await self.client.send(req, stream=stream) + + if stream == False: + r.raise_for_status() + try: + r = r.json() + except: + r = r.text return r def _make_request(self, endpoint: str, payload=None): diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index a20654d00..4b9a0d50a 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -358,54 +358,62 @@ async def rollout_task( response="",
finish_reason="failed", ) - if response.status_code != 200: - error_body = await response.atext() - self.logger.error(f"Request {uid} failed with status {response.status_code}: {error_body}") - return failed_rollout_response - last_trajectory = "" finish_reason = "" - async for chunk in response.aiter_lines(): - # chunk example - # data: {"id":"1","object":"chat.completion.chunk","created":1757495636,"model":"qwen3-8b","choices":[{"index":0,"delta":{"role":"assistant","content":"","reasoning_content":null,"tool_calls":[]},"logprobs":null,"finish_reason":null}],"usage":null} - if not chunk.startswith("data:"): - continue - try: - chunk_data_str = chunk[len("data:"):].strip() - if self.paused or chunk_data_str == "[DONE]": - finish_reason = "paused" if self.paused else finish_reason - break - if not (chunk_data_str.startswith("{") and chunk_data_str.endswith("}")): - continue - - chunk_data = json.loads(chunk_data_str) - if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': - # TODO: 不太懂,好像是一个假的流式,每次返回的都会包括之前的? - last_trajectory = chunk_data['text'] - if "output_token_logprobs" in chunk_data["meta_info"]: - new_response_tokens = [item[1] for item in chunk_data["meta_info"]["output_token_logprobs"]] - new_response_log_probs = [item[0] for item in chunk_data["meta_info"]["output_token_logprobs"]] - finish_reason = chunk_data["meta_info"].get("finish_reason") - if finish_reason is not None: - assert isinstance(finish_reason, dict) - finish_reason = finish_reason['type'] - else: - delta_content = chunk_data["choices"][0]["delta"].get("content") - last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory - finish_reason = chunk_data["choices"][0].get("finish_reason") - - # todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids. 
- if finish_reason == "stop": - assert len(sample_params["stops"]) == 1 - last_trajectory += sample_params["stops"][0] - - except json.JSONDecodeError as e: - self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}") - continue - except Exception as e: - self.logger.error(f"Error processing chunk for {uid}: {chunk}, error: {e}") + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + # non-stream mode + if "output_token_logprobs" in response["meta_info"]: + new_response_tokens = [item[1] for item in response["meta_info"]["output_token_logprobs"]] + new_response_log_probs = [item[0] for item in response["meta_info"]["output_token_logprobs"]] + assert len(new_response_tokens) <= sample_params["max_tokens"], f'generated length exceeds the limit: generated {len(new_response_tokens)}, limit {sample_params["max_tokens"]}' + last_trajectory = response['text'] + finish_reason = response["meta_info"]["finish_reason"]["type"] + else: + if response.status_code != 200: + error_body = await response.atext() + self.logger.error(f"Request {uid} failed with status {response.status_code}: {error_body}") + return failed_rollout_response + async for chunk in response.aiter_lines(): + # chunk example + # data: {"id":"1","object":"chat.completion.chunk","created":1757495636,"model":"qwen3-8b","choices":[{"index":0,"delta":{"role":"assistant","content":"","reasoning_content":null,"tool_calls":[]},"logprobs":null,"finish_reason":null}],"usage":null} + if not chunk.startswith("data:"): + continue + try: + chunk_data_str = chunk[len("data:"):].strip() + if self.paused or chunk_data_str == "[DONE]": + finish_reason = "paused" if self.paused else finish_reason + break + if not (chunk_data_str.startswith("{") and chunk_data_str.endswith("}")): + continue + + chunk_data = json.loads(chunk_data_str) + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + # TODO: not fully understood; this looks like pseudo-streaming where each chunk already contains everything returned so far? + last_trajectory = chunk_data['text'] + if "output_token_logprobs" in chunk_data["meta_info"]: + new_response_tokens = [item[1] for item in chunk_data["meta_info"]["output_token_logprobs"]] + new_response_log_probs = [item[0] for item in chunk_data["meta_info"]["output_token_logprobs"]] + finish_reason = chunk_data["meta_info"].get("finish_reason") + if finish_reason is not None: + assert isinstance(finish_reason, dict) + finish_reason = finish_reason['type'] + else: + delta_content = chunk_data["choices"][0]["delta"].get("content") + last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory + finish_reason = chunk_data["choices"][0].get("finish_reason") + + # todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids.
+ if finish_reason == "stop": + assert len(sample_params["stops"]) == 1 + last_trajectory += sample_params["stops"][0] + + except json.JSONDecodeError as e: + self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}") + continue + except Exception as e: + self.logger.error(f"Error processing chunk for {uid}: {chunk}, error: {e}") + return failed_rollout_response assert finish_reason in ["stop", "length", "tool_call", "paused", "failed"], ( f"Unexpected finish_reason: {finish_reason}" @@ -432,9 +440,10 @@ async def rollout_task( self.logger.error(f"An unexpected error occurred in rollout_task for {uid}: {e}") return failed_rollout_response finally: - # make sure the response is closed in all cases - if response: - await response.aclose() + if os.environ.get("ID_INPUT_OUTPUT", '0') == '0': + # make sure the response is closed in all cases + if response: + await response.aclose() async def rollout( self, diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index ec2ee65f1..fe10610c6 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -410,15 +410,41 @@ def _prepare_train_data(self, data_groups, pack_max_length): def _save_trajectories(self, data_groups, save_path): rewards = [] response_len_list = [] + + rollout_response_len_list = [] + rollout_id_to_str = [] + revert_rollout_ids = [] for group in data_groups: for data in group: rewards.append(data["reward"]) + if "response_ids" in data and data["response_ids"] is not None: + if isinstance(data["response_ids"], torch.Tensor): + response_ids = data["response_ids"].flatten().tolist() + else: + response_ids = data["response_ids"] + rollout_response_len_list.append(len(response_ids)) + + response_str = self.tokenizer.decode(response_ids, skip_special_tokens=False) + revert_encode_response_ids = self.tokenizer.encode(response_str, add_special_tokens=False) + + if response_ids != revert_encode_response_ids: + print("Warning: response_ids and revert_encode_response_ids are not the same!") + revert_rollout_ids.append(response_ids) + else: + revert_rollout_ids.append(None) + + rollout_id_to_str.append(response_str) + response_ids = self.tokenizer.encode(data["response_str"], add_special_tokens=False) response_len_list.append(len(response_ids)) rewards = torch.tensor(rewards).float() response_lens = torch.tensor(response_len_list).float() + rollout_response_lens = None + if len(rollout_response_len_list) > 0: + rollout_response_lens = torch.tensor(rollout_response_len_list).float() + _count = 0 with open(save_path, "w", encoding="utf-8") as f: item = { @@ -432,6 +458,14 @@ def _save_trajectories(self, data_groups, save_path): "response_len_min": response_lens.min().item(), "total_len": len(rewards), } + if len(rollout_response_len_list) > 0: + item.update({ + "rollout_response_len_mean": rollout_response_lens.mean().item(), + "rollout_response_len_std": rollout_response_lens.std().item(), + "rollout_response_len_max": rollout_response_lens.max().item(), + "rollout_response_len_min": rollout_response_lens.min().item(), + }) + json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") for group in data_groups: @@ -443,6 +477,17 @@ def _save_trajectories(self, data_groups, save_path): "label": data["reward_model"]["ground_truth"], "reward": data["reward"], } + + if len(rollout_id_to_str) > 0: + revert_rollout_id = revert_rollout_ids[_count] + if revert_rollout_id is not None: + item["rollout_ids"] = revert_rollout_id + + if response_len_list[_count] != rollout_response_len_list[_count]: + item["rollout_response_len"] =
rollout_response_len_list[_count] + if len(rollout_id_to_str) > 0 and data["response_str"] != rollout_id_to_str[_count]: + item["rollout_id_to_str"] = rollout_id_to_str[_count] + json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") _count += 1 From 007807fca4583fb560adba17b786354eee8152b5 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 30 Sep 2025 17:48:30 +0800 Subject: [PATCH 22/22] update lmdeploy --- ci/scripts/test_dapo.sh | 2 ++ xtuner/v1/ray/rollout/lmdeploy.py | 54 ++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/ci/scripts/test_dapo.sh b/ci/scripts/test_dapo.sh index 23d9870a2..b4786c850 100644 --- a/ci/scripts/test_dapo.sh +++ b/ci/scripts/test_dapo.sh @@ -6,6 +6,8 @@ export PYTHONPATH=/mnt/shared-storage-user/huanghaian/code/lmdeploy/:$PYTHONPATH export UVICORN_LOG_LEVEL="CRITICAl" export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' +#export ID_INPUT_OUTPUT=1 + OUTPUT_DIR='work_dirs/dapo_math_7B_newlmdeploy_nogroup' if [ ! -d "$OUTPUT_DIR" ]; then mkdir -p "$OUTPUT_DIR" fi diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/ray/rollout/lmdeploy.py index dbd4c4752..34d6091dc 100644 --- a/xtuner/v1/ray/rollout/lmdeploy.py +++ b/xtuner/v1/ray/rollout/lmdeploy.py @@ -5,7 +5,7 @@ import ray import requests from ray.util.placement_group import placement_group_table - +from transformers import AutoTokenizer from xtuner.v1.ray.config import RolloutConfig from .worker import RolloutWorker @@ -62,7 +62,11 @@ def __init__( self.server_func = run_lmdeploy_server_wrapper self.router_func_str = "lmdeploy.serve.proxy.proxy.proxy" self.endpoints["health_generate"] = "health" - self.endpoints["generate"] = "v1/chat/completions" + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + self.endpoints["generate"] = "generate" + self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) + else: + self.endpoints["generate"] = "v1/chat/completions" self.endpoints["output_ids"] = "output_ids" self.endpoints["response"] = "text" self.endpoints["sleep"] = "sleep" @@ -98,22 +102,45 @@ async def _create_request( "Content-Type": "application/json", "Authorization": f"Bearer {self.api_keys}", # in case auth is required } - payload = { - "model": self.model_name, - "messages": prompt, - "tools": tools, - "tool_choice": tool_choice, - "stream": True, - } - payload.update(sample_params) - payload.update(extra_params) + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + stream = False + payload = {"model": self.model_name, "stream": stream, "return_logprob": True} + text_prompt = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + payload["input_ids"] = prompt_token_ids + + new_sample_params = {"max_tokens": sample_params['max_tokens'], + "temperature": sample_params['temperature'], + "top_p": sample_params['top_p'], + "top_k": sample_params['top_k'], + "include_stop_str_in_output": True, + "skip_special_tokens": False, + "spaces_between_special_tokens": False, + } + payload.update(new_sample_params) + else: + stream = True + payload = { + "model": self.model_name, + "messages": prompt, + "stream": stream, + } + payload.update(sample_params) + payload.update(extra_params) req = self.client.build_request( "POST", url, headers=headers, json=payload, ) - r = await self.client.send(req, stream=True) + r = await self.client.send(req, stream=stream) + + if stream == False: + r.raise_for_status() + try: + r = r.json() + except: + r =
r.text return r def get_logprobs(self, input_ids, sampling_params): @@ -278,6 +305,9 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: lmdeploy_config_kwargs["log_level"] = "CRITICAL" # disable logging + if os.environ.get("ID_INPUT_OUTPUT", '0') == '1': + backend_config.logprobs_mode = 'raw_logprobs' # TODO: only the pytorch backend is supported + return Namespace( model_path=self.config.model_path, model_name=self.model_name,
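
Illustrative sketch (not part of the patches above): the logprob handling introduced in PATCH 19/20 boils down to two steps — align the rollout logprobs with the shifted labels by padding the prompt positions with zeros, then compare them against the logprobs recomputed by the training worker and report drift statistics. The standalone Python below approximates that logic with hypothetical helper names (align_rollout_logprobs, logprob_drift_stats); it is a simplified sketch, not the exact xtuner implementation.

import torch

def align_rollout_logprobs(prompt_len: int, response_logprobs: list) -> torch.Tensor:
    # Shifted labels cover (prompt_len - 1) masked prompt positions, the response tokens,
    # and one trailing masked position, so the logprobs are padded the same way.
    aligned = [0.0] * (prompt_len - 1) + list(response_logprobs) + [0.0]
    return torch.tensor(aligned, dtype=torch.float32).unsqueeze(0)  # (1, seq_len)

def logprob_drift_stats(rollout_logprobs: torch.Tensor,
                        recomputed_logprobs: torch.Tensor,
                        shifted_labels: torch.Tensor) -> dict:
    # Only positions with a real label (!= -100) contribute, mirroring the mask used in worker.fit().
    mask = shifted_labels != -100
    diff = rollout_logprobs[mask] - recomputed_logprobs[mask]
    if diff.numel() == 0:  # fully padded pack
        return {"min": 0.0, "max": 0.0, "mean": 0.0, "std": 0.0}
    std = diff.std().item() if diff.numel() > 1 else 0.0
    return {"min": diff.min().item(), "max": diff.max().item(), "mean": diff.mean().item(), "std": std}

if __name__ == "__main__":
    prompt_len, response_ids, response_logprobs = 4, [101, 102, 103], [-0.11, -0.52, -0.03]
    shifted_labels = torch.tensor([[-100] * (prompt_len - 1) + response_ids + [-100]])
    rollout = align_rollout_logprobs(prompt_len, response_logprobs)
    recomputed = rollout + 0.01 * torch.randn_like(rollout)  # stand-in for compute_actor_logprobs output
    print(logprob_drift_stats(rollout, recomputed, shifted_labels))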