diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 533f55d332b..2ab905806c4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -20,6 +20,7 @@ /verl/workers/actor/megatron_actor.py @ISEEKYAN @vermouth1992 /verl/workers/critic/megatron_critic.py @ISEEKYAN @vermouth1992 /verl/workers/megatron_workers.py @ISEEKYAN @vermouth1992 +/verl/experimental @wuxibin89 @ArronHZG /tests/single_controller @zw0610 @wuxibin89 /tests/trainer @eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2db9aa020fa..91c0d21f2a3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward` + - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. 
diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml index c7aab7be9cd..578a67d3e2b 100644 --- a/.github/workflows/e2e_ascend.yml +++ b/.github/workflows/e2e_ascend.yml @@ -126,6 +126,10 @@ jobs: ray stop --force export PYTHONPATH=$PYTHONPATH:/Megatron-LM USE_DIST_CKPT=True USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json DUMMY_MODEL_PATH=$HOME/dist_ckpt/qwen3_30b_grpo_mindspeed bash tests/special_npu/run_qwen3_30b_grpo_mindspeed.sh + - name: Running the E2E test with fully_async_policy algorithm (FSDP2) + run: | + ray stop --force + bash tests/special_npu/run_fully_async_policy.sh vlm_rl_job: if: github.repository_owner == 'verl-project' diff --git a/.github/workflows/e2e_one_step_off_policy_ascend.yml b/.github/workflows/e2e_one_step_off_policy_ascend.yml index 3738ede9331..f1b0054a33a 100644 --- a/.github/workflows/e2e_one_step_off_policy_ascend.yml +++ b/.github/workflows/e2e_one_step_off_policy_ascend.yml @@ -68,7 +68,7 @@ on: # Entrypoints - ".github/workflows/e2e_one_step_off_policy_ascend.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_one_step_off_policy.sh" + - "tests/special_npu/run_one_step_off_policy.sh" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -122,7 +122,7 @@ jobs: - name: Running the E2E test with one_step_off_policy algorithm (FSDP2) run: | ray stop --force - bash tests/special_e2e/run_one_step_off_policy.sh + bash tests/special_npu/run_one_step_off_policy.sh # Test Megatron strategy e2e_one_step_off_policy_megatron_ascend: @@ -167,4 +167,4 @@ jobs: run: | ray stop --force export PYTHONPATH=$PYTHONPATH:/Megatron-LM - bash tests/special_e2e/run_one_step_off_policy.sh + bash tests/special_npu/run_one_step_off_policy.sh diff --git a/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml b/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml index 04be1af3b15..f2cdacd0f31 100644 --- 
a/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml +++ b/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml @@ -134,7 +134,7 @@ jobs: - name: Running GEO3K E2E training tests on 8 L20 GPUs with veomni engine (FSDP_SIZE=8, USP=1) run: | ray stop --force - MODEL_ID=Qwen/Qwen3-VL-2B-Instruct TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/gsm8k/test.parquet VAL_BEFORE_TRAIN=True NUM_GPUS=8 FSDP_SIZE=8 SP_SIZE=1 EP_SIZE=1 VERL_EXP_NAME="qwen3-2b-vl-function-reward-minimal-fsdp-size8" bash tests/special_e2e/run_ppo_trainer_veomni.sh + MODEL_ID=Qwen/Qwen3-VL-2B-Instruct TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/gsm8k/test.parquet VAL_BEFORE_TRAIN=True NUM_GPUS=8 FSDP_SIZE=4 SP_SIZE=2 EP_SIZE=1 VERL_EXP_NAME="qwen3-2b-vl-function-reward-minimal-fsdp-size8" bash tests/special_e2e/run_ppo_trainer_veomni.sh cleanup: runs-on: ubuntu-latest diff --git a/.github/workflows/e2e_sft_llm.yml b/.github/workflows/e2e_sft_llm.yml index 448e433bdea..66515a422d2 100644 --- a/.github/workflows/e2e_sft_llm.yml +++ b/.github/workflows/e2e_sft_llm.yml @@ -110,7 +110,7 @@ jobs: - name: Prepare gsm8k dataset run: | ray stop --force - python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k + python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm run: | ray stop --force @@ -123,10 +123,6 @@ jobs: run: | ray stop --force SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - - name: Check loss difference between sequence parallel vs. 
default implementation - run: | - ray stop --force - ENTRYPOINT="tests/special_e2e/sft/test_sp_loss_match.py" SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with sequence parallism and liger run: | ray stop --force @@ -140,10 +136,6 @@ jobs: ray stop --force LORA_RANK=32 RESUME_MODE=auto TOTAL_TRAIN_STEP=2 bash tests/special_e2e/sft/run_sft.sh # TODO: multiturn - - name: Prepare gsm8k dataset - run: | - ray stop --force - python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - name: Running GSM8K E2E training tests with multiturn and various configs and compare results run: | bash tests/special_e2e/sft/test_sft_engine_all.sh diff --git a/.github/workflows/e2e_sft_llm_ascend.yml b/.github/workflows/e2e_sft_llm_ascend.yml index 825fb265647..4ccd074cefa 100644 --- a/.github/workflows/e2e_sft_llm_ascend.yml +++ b/.github/workflows/e2e_sft_llm_ascend.yml @@ -109,7 +109,7 @@ jobs: ln -s /root/.cache/models ~/models - name: Prepare gsm8k dataset run: | - python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k + python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - name: Running GSM8K E2E training tests on 8 NPUs with rmpad using function rm run: | ray stop --force @@ -122,10 +122,6 @@ jobs: run: | ray stop --force SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - - name: Check loss difference between sequence parallel vs. 
default implementation - run: | - ray stop --force - ENTRYPOINT="tests/special_e2e/sft/test_sp_loss_match.py" SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests with LoRA run: | ray stop --force @@ -134,11 +130,6 @@ jobs: run: | ray stop --force LORA_RANK=32 RESUME_MODE=auto TOTAL_TRAIN_STEP=2 bash tests/special_e2e/sft/run_sft.sh - # TODO: multiturn - - name: Prepare gsm8k dataset - run: | - ray stop --force - python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - name: Running GSM8K E2E training tests with multiturn and various configs and compare results run: | export PYTHONPATH=$PYTHONPATH:/Megatron-LM diff --git a/.github/workflows/e2e_transferqueue.yml b/.github/workflows/e2e_transferqueue.yml deleted file mode 100644 index 309f715be1b..00000000000 --- a/.github/workflows/e2e_transferqueue.yml +++ /dev/null @@ -1,172 +0,0 @@ -# # Tests layout - -# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: -# - `tests/trainer` for testing functionality related to `verl/trainer` -# - `tests/models` for testing functionality related to `verl/models` -# - ... - -# There are a few folders with `special_` prefix, created for special purposes: -# - `special_distributed`: unit tests that must run with multiple GPUs -# - `special_e2e`: end-to-end tests with training/generation scripts -# - `special_npu`: tests for NPUs -# - `special_sanity`: a suite of quick sanity tests -# - `special_standalone`: a set of test that are designed to run in dedicated environments - -# Accelerators for tests -# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. -# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. - -# # Workflow layout - -# All CI tests are configured by yaml files in `.github/workflows/`. 
Here's an overview of all test configs: -# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` -# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` -# 3. End-to-end tests: `e2e_*.yml` -# 4. Unit tests -# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` -# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. -# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when -# - new workflow yaml is added to `.github/workflows` -# - new tests are added to workflow mentioned in 2. - -name: e2e_transferqueue - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - # For push, for now only anti-patterns are specified so it is more conservative - # and achieves higher coverage. - push: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/*trainer*" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - - "verl/experimental/transfer_queue/**" - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Home - - "verl/experimental/transfer_queue" - # Entrypoints - - ".github/workflows/e2e_transferqueue.yml" - - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_transferqueue.sh" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. 
-permissions: - contents: read - -env: - IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:vllm012.dev3" - DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" - -jobs: - setup: - if: github.repository_owner == 'verl-project' - runs-on: ubuntu-latest - outputs: - runner-label: ${{ steps.create-runner.outputs.runner-label }} - mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} - steps: - - uses: actions/checkout@v4 - - id: create-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "create" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-image: "${{ env.IMAGE }}" - - # Test FSDP strategy - e2e_transferqueue_fsdp: - needs: setup - runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] - timeout-minutes: 10 # Increase timeout for async training - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "fsdp" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -r requirements-test.txt - pip3 install --no-deps -e . 
- pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - - name: Running the E2E test with TransferQueue (FSDP), enable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True - bash tests/special_e2e/run_transferqueue.sh - - # Test Megatron strategy - e2e_transferqueue_megatron: - needs: setup - runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] - timeout-minutes: 10 # Increase timeout for async training - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "megatron" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -r requirements-test.txt - pip3 install --no-deps -e . 
- pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - - name: Running the E2E test with TransferQueue (Megatron), disable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=False - bash tests/special_e2e/run_transferqueue.sh - - cleanup: - runs-on: ubuntu-latest - needs: [setup, e2e_transferqueue_fsdp, e2e_transferqueue_megatron] - if: always() - steps: - - id: destroy-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "destroy" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" diff --git a/.github/workflows/e2e_transferqueue_ascend.yml b/.github/workflows/e2e_transferqueue_ascend.yml deleted file mode 100644 index ca1dc2de87e..00000000000 --- a/.github/workflows/e2e_transferqueue_ascend.yml +++ /dev/null @@ -1,174 +0,0 @@ -# # Tests layout - -# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: -# - `tests/trainer` for testing functionality related to `verl/trainer` -# - `tests/models` for testing functionality related to `verl/models` -# - ... - -# There are a few folders with `special_` prefix, created for special purposes: -# - `special_distributed`: unit tests that must run with multiple GPUs -# - `special_e2e`: end-to-end tests with training/generation scripts -# - `special_npu`: tests for NPUs -# - `special_sanity`: a suite of quick sanity tests -# - `special_standalone`: a set of test that are designed to run in dedicated environments - -# Accelerators for tests -# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. -# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. 
- -# # Workflow layout - -# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: -# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` -# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` -# 3. End-to-end tests: `e2e_*.yml` -# 4. Unit tests -# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` -# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. -# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when -# - new workflow yaml is added to `.github/workflows` -# - new tests are added to workflow mentioned in 2. - -name: e2e_transferqueue_ascend - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - # For push, for now only anti-patterns are specified so it is more conservative - # and achieves higher coverage. 
- push: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/*trainer*" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - - "verl/experimental/transfer_queue/**" - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Home - - "verl/experimental/transfer_queue" - # Entrypoints - - ".github/workflows/e2e_transferqueue_ascend.yml" - - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_transferqueue.sh" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - # Test FSDP strategy - e2e_transferqueue_fsdp_ascend: - if: github.repository_owner == 'verl-project' - runs-on: linux-aarch64-a2-8 - timeout-minutes: 60 # Increase this timeout value as needed - container: - image: swr.ap-southeast-1.myhuaweicloud.com/base_image/ascend-ci/verl/verl:verl-8.3.rc1-910b-ubuntu22.04-py3.11-latest - options: >- - --shm-size 16g - env: - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "fsdp" - steps: - - name: Check npu and CANN info - run: | - cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info - npu-smi info - - name: Check initial pip list from image - run: | - pip list - - name: Checkout verl-project/verl repo - uses: actions/checkout@v4 - with: - fetch-depth: 0 - clean: true - - name: Install the current repository - run: | - pip install -r requirements-npu.txt - pip install --no-deps -e . 
- pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Check final pip list - run: | - pip list - - name: Prepare weights - run: | - ln -s /root/.cache/models ~/models - - name: Prepare GSM8K dataset - run: | - python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - - name: Running the E2E test with TransferQueue (FSDP), enable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True - bash tests/special_e2e/run_transferqueue.sh - - # Test Megatron strategy - e2e_transferqueue_megatron_ascend: - if: github.repository_owner == 'verl-project' - runs-on: linux-aarch64-a2-8 - timeout-minutes: 60 # Increase this timeout value as needed - container: - image: swr.ap-southeast-1.myhuaweicloud.com/base_image/ascend-ci/verl/verl:verl-8.3.rc1-910b-ubuntu22.04-py3.11-latest - options: >- - --shm-size 16g - env: - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "megatron" - steps: - - name: Check npu and CANN info - run: | - cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info - npu-smi info - - name: Check initial pip list from image - run: | - pip list - - name: Checkout verl-project/verl repo - uses: actions/checkout@v4 - with: - fetch-depth: 0 - clean: true - - name: Install the current repository - run: | - pip install -r requirements-npu.txt - pip install --no-deps -e . 
- pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Check final pip list - run: | - pip list - - name: Prepare weights - run: | - ln -s /root/.cache/models ~/models - - name: Prepare GSM8K dataset - run: | - python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - - name: Running the E2E test with TransferQueue (Megatron), disable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=False - export PYTHONPATH=$PYTHONPATH:/Megatron-LM - bash tests/special_e2e/run_transferqueue.sh diff --git a/.github/workflows/gpu_unit_tests.yml b/.github/workflows/gpu_unit_tests.yml index d16075e19ea..eaafeadb15e 100644 --- a/.github/workflows/gpu_unit_tests.yml +++ b/.github/workflows/gpu_unit_tests.yml @@ -108,7 +108,7 @@ jobs: pip3 install hf_transfer pip3 install -r requirements-test.txt pip3 install --no-deps -e . - pip3 install cupy-cuda12x pytest-asyncio + pip3 install cupy-cuda12x==13.6.0 pytest-asyncio pip3 install --ignore-installed blinker pip3 install --ignore-installed mlflow "numpy<2.0" - name: Run all GPU unit tests diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml index f9158731bac..2c928a9527e 100644 --- a/.github/workflows/sgl.yml +++ b/.github/workflows/sgl.yml @@ -113,7 +113,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install cupy-cuda12x pytest-asyncio + pip3 install cupy-cuda12x==13.6.0 pytest-asyncio pip3 install hf_transfer fastmcp pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . @@ -144,7 +144,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install cupy-cuda12x pytest-asyncio + pip3 install cupy-cuda12x==13.6.0 pytest-asyncio pip3 install hf_transfer fastmcp pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . 
diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 84cdb45c3f6..15c51678030 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -144,7 +144,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install cupy-cuda12x pytest-asyncio + pip3 install pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . pip3 install --upgrade "transformers<5.0" diff --git a/.gitignore b/.gitignore index 62d4dcfc815..e6c0f5a08e3 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ **/playground **/wandb +/pyrightconfig.json + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index cea8b9945e9..a7d04fc5ba5 100644 --- a/README.md +++ b/README.md @@ -237,6 +237,7 @@ Welcome to register your awesome project build with `verl` for other developers' - [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling deep research via reinforcement learning in real-world environments  - [VAGEN](https://github.com/RAGEN-AI/VAGEN): Training VLM agents with multi-turn reinforcement learning  - [RM-R1](https://arxiv.org/abs/2505.02387): RL training of reasoning reward models  +- [Dr. 
MAS](https://arxiv.org/pdf/2602.08847): Stable **end-to-end RL** post-training for **multi-agent LLM systems**  - [LUFFY](https://arxiv.org/pdf/2504.14945): Learning to Reason under Off-Policy Guidance - [DeepMath](https://github.com/zwhe99/DeepMath): DeepMath-103K data and series models for math reasoning - [PACS](https://github.com/ritzz-ai/PACS): Implicit Actor Critic Coupling via a Supervised Learning Framework for RLVR  @@ -283,6 +284,7 @@ Welcome to register your awesome project build with `verl` for other developers' - [NoisyRollout](https://github.com/NUS-TRAIL/NoisyRollout): Reinforcing Visual Reasoning with Data Augmentation  - [SPEAR](https://github.com/TencentYoutuResearch/SPEAR): **Self-imitation** with **Progressive Exploration** for Agentic Reinforcement Learning (ICLR 2026)  - [RuleReasoner](https://github.com/bigai-nlco/RuleReasoner): **RuleReasoner:** Reinforced Rule-based Reasoning via **Domain-aware Dynamic Sampling** (ICLR 2026)  +- [MetaphorStar](https://metaphorstar.github.io/): **Image Metaphor** Understanding and Reasoning with End-to-End **Visual Reinforcement Learning**  ## Contribution Guide diff --git a/docker/Dockerfile.stable.vllm b/docker/Dockerfile.stable.vllm index 158eccac580..701713d1121 100644 --- a/docker/Dockerfile.stable.vllm +++ b/docker/Dockerfile.stable.vllm @@ -32,6 +32,9 @@ RUN pip install torch==2.9.1 torchvision torchaudio --index-url https://download RUN sed -i '/nvidia-cudnn-cu12/d' /usr/local/lib/python3.12/dist-packages/torch-2.9.1+cu129.dist-info/METADATA RUN pip install --no-deps --force-reinstall nvidia-cudnn-cu12==9.16.0.29 +# NOTE: This installs the `vllm` source code in `/vllm`. +# This might break the (based)pyright type checking. To fix it, add `/vllm` to `extraPaths` in `pyrightconfig.json`. +# c.f. 
https://docs.basedpyright.com/latest/configuration/config-files/ RUN git clone --depth 1 -b v0.12.0 https://github.com/vllm-project/vllm.git && \ cd vllm && \ find requirements -name "*.txt" -print0 | xargs -0 sed -i '/torch/d' && \ diff --git a/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 b/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 index 2d9993cd921..baa55c0e348 100644 --- a/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 +++ b/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 @@ -18,7 +18,7 @@ RUN ARCH=$(uname -m) && \ fi && \ # Clone libs git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm.git && \ - git clone --depth 1 --branch v0.11.0rc1 https://github.com/vllm-project/vllm-ascend.git && \ + git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm-ascend.git && \ git clone https://gitcode.com/Ascend/MindSpeed.git && \ cd MindSpeed && git checkout f2b0977e && cd .. && \ git clone --depth 1 --branch core_v0.12.1 https://github.com/NVIDIA/Megatron-LM.git diff --git a/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 b/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 index bead95cea6c..f5d74a7e70a 100644 --- a/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 +++ b/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 @@ -18,7 +18,7 @@ RUN ARCH=$(uname -m) && \ fi && \ # Clone libs git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm.git && \ - git clone --depth 1 --branch v0.11.0rc1 https://github.com/vllm-project/vllm-ascend.git && \ + git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm-ascend.git && \ git clone https://gitcode.com/Ascend/MindSpeed.git && \ cd MindSpeed && git checkout f2b0977e && cd .. 
&& \ git clone --depth 1 --branch core_v0.12.1 https://github.com/NVIDIA/Megatron-LM.git diff --git a/docs/advance/fully_async.md b/docs/advance/fully_async.md index a2b7ccb3aea..e7c8962ad9c 100644 --- a/docs/advance/fully_async.md +++ b/docs/advance/fully_async.md @@ -106,9 +106,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev | `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization | | `async_training.staleness_threshold` | Freshness control | | `async_training.partial_rollout` | Whether to perform partial_rollout | -| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for accelerating, default `True` | -| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False` | -| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096` | | `async_training.use_trainer_do_validate` | Whether use trainer node to do validate process, default `False` | **Further Explanation:** @@ -182,27 +179,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev mode d (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`. -* `async_training.checkpoint_engine.enable` - - Enabling the checkpoint engine generally reduces synchronization time overhead by more than 60% compared to - the original per-tensor parameter synchronization method. However, assembling buckets incurs additional - temporary GPU memory overhead. - -* `async_training.checkpoint_engine.overlap_broadcast_and_consume` - - Enabling pipeline between the broadcast and load_weights parameters will allocate additional GPU memory. 
- Since the main time consumption for parameter synchronization is not in the broadcast and load_weights phases, - but in the parameter generation phase (by megatron or FSDP), this option is off by default. - -* `async_training.checkpoint_engine.device_buffer_size_M` - - It controls the size of the memory buffer used for synchronization when the checkpoint-engine is enabled. - The actual `bucket_size` = `max(device_buffer_size_M, maximum parameter tensor size)`. - * When enable `overlap_broadcast_and_consume`, the additional device memory overhead of - trainer rank is `3 * bucket_size`and rollout rank is `2 * bucket_size`。 - * When disable `overlap_broadcast_and_consume`, the additional device memory overhead of - trainer rank is `2 * bucket_size`and rollout rank is `1 * bucket_size`。 - * `async_training.use_trainer_do_validate` It controls whether to use the trainer's `do_validate` method for validation. diff --git a/docs/advance/mtp.md b/docs/advance/mtp.md index b4c5a25c631..5f1698d3ddc 100644 --- a/docs/advance/mtp.md +++ b/docs/advance/mtp.md @@ -2,19 +2,21 @@ **Author**: `https://github.com/meituan-search` -Last updated: 01/30/2026 +Last updated: 02/15/2026 # 1. Scope of Support Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek series models based on the MTP architecture. 
The support rules for training and inference engines are as follows: -- **Training Engine**: Only supports the `mbridge + megatron` combination; other training engines are not compatible at this time; +- **Training Engine**: Only supports the `mbridge/Megatron-Bridge + megatron` combination; other training engines are not compatible at this time; - **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list; - **Dependency Versions**: - - mbridge: Use the specified branch: [https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp](https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp) (will be merged into the main branch in the future); + - mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future); + + - Megatron-Bridge: Apply the patches and review suggestions from PR if you want to try out mimo-7B-RL: [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) (will be merged into the main branch in the future); - megatron: Use the latest dev version (commit: [23e092f41ec8bc659020e401ddac9576c1cfed7e](https://github.com/NVIDIA/Megatron-LM/tree/23e092f41ec8bc659020e401ddac9576c1cfed7e)), which supports MTP + CP training methods. diff --git a/docs/algo/dppo.md b/docs/algo/dppo.md new file mode 100644 index 00000000000..aab22bf7fa6 --- /dev/null +++ b/docs/algo/dppo.md @@ -0,0 +1,96 @@ +# Divergence Proximal Policy Optimization (DPPO) + +Last updated: 02/25/2026. + + +
+
+
+
+
-
-- 软件栈工作在hostcpu,通信算法展开一个个task
-- 每个task调用runtime接口,下发到device的rtsqueue
-- STARS从rstqueue上顺序拿取task
-- 根据task类型分别调用掉SDMA和RDMA引擎。
- **单算子瓶颈**:hostbound 每个task提交是2~5us,一个通信算子有几百个task,单算子场景不会在device上缓存,下发一个执行一个
-
-###### AICpu机制展开
-
-
-
-- host侧不下发一个个task,把通信算子作为一个个kernel,放在通信算子kernel的队列上去。
-- STARS调度kernel队列流上的kernel,把kernel放到AiCPU上去执行。
-- AICPU调用函数(kernel),用一个线程执行kernel 函数,在函数内把通信task展开,把task放到rstqueue上,STARS调用。
-- 降低host和aicpu交互,由几百次降低为一次。
-- task的提交在AICPU上提交,做了提交的部分合并。
-
-##### TASK_QUEUE_ENABLE
-
-**使用方式:**`export TASK_QUEUE_ENABLE=2`
-
-TASK_QUEUE_ENABLE,下发优化,图模式设置为1(即开启图模式的时候这个要设置为1),非图模式设置为2
-
-示意图:
-
-
-
-##### 绑核优化
-
-**使用方式:**`export CPU_AFFINITY_CONF=1`
-
-详细设置原理可看:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0059.html
-
-#### 其他
-
-以下内容汇总了若干全局环境变量的调优配置。由于这些参数在训练阶段与推理阶段往往都能带来正向收益,且目前尚缺乏足够精细的消融实验来严格区分它们各自对训练或推理的贡献占比,故统一归拢在此,供后续持续监控与进一步拆解分析。
-
-##### 使能jemalloc
-
-使用方式(注意需要先安装jemalloc库):`export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2`
-
-**安装使用教程:**[MindSpeed-RL/docs/install_guide.md · Ascend/MindSpeed-RL - AtomGit | GitCode](https://gitcode.com/Ascend/MindSpeed-RL/blob/master/docs/install_guide.md#高性能内存库-jemalloc-安装)
-
-##### 多流复用
-
-内存方面有优化
-
-使能方式:`export MULTI_STREAM_MEMORY_REUSE=1`
-
-原理介绍:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0040.html
-
-##### VLLM_ASCEND_ENABLE_FLASHCOMM
-
-使用方式:`export VLLM_ASCEND_ENABLE_FLASHCOMM=1`
-
-启用昇腾 NPU 特有的FLASHCOMM高速通信优化技术
-
-地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html
-
-##### VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
-
-使用方式:`export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1`
-
-启用昇腾 NPU针对大模型推理的稠密计算优化
-
-地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html
-
-##### VLLM_ASCEND_ENABLE_PREFETCH_MLP
-
-使用方式:`export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1`
-
-启用 MLP 层的权重预取机制
-
-
-
-##### verl框架参数设置
-
-主要是内存方面的一些设置开关(注意,这个里面的优化都或多或少会导致吞吐量有一定程度的劣化)
-
-~~~bash
-# 梯度检查点 (Gradient Checkpointing)
-# 作用: 通过重新计算激活值来节省显存,以计算换内存。在前向传播时不保存中间激活值,反向传播时重新计算,可以显著降低显存占用,允许使用更大的batch size。
-actor_rollout_ref.model.enable_gradient_checkpointing=True
-
-# 参数卸载 (Parameter Offload)
-# 作用: 将模型参数卸载到CPU内存,训练时再加载回GPU。
-actor_rollout_ref.actor.fsdp_config.param_offload=${offload} # True
-actor_rollout_ref.ref.fsdp_config.param_offload=${offload} # True
-
-# 优化器状态卸载 (Optimizer Offload)
-# 作用: 将优化器状态(如Adam的动量)卸载到CPU。优化器状态通常占用大量显存(对于Adam,每个参数需要额外8字节),卸载可以节省显存。
-actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} # True
-
-# 释放推理引擎缓存 (Free Cache Engine)
-# 作用: 在训练阶段释放推理引擎的KV cache和权重。这是3D-HybridEngine的核心优化,允许在同一GPU上交替进行推理和训练,显著降低显存需求。
-actor_rollout_ref.rollout.free_cache_engine=True
-
-# 熵计算优化
-# entropy_checkpointing: 在训练时对熵计算启用重计算,降低显存峰值
-# entropy_from_logits_with_chunking: 分块处理logits张量(如2048 tokens一组),避免一次性加载整个[bsz*seq_len, vocab]张量
-actor_rollout_ref.actor.entropy_checkpointing=True
-actor_rollout_ref.ref.entropy_checkpointing=True
-actor_rollout_ref.actor.entropy_from_logits_with_chunking=True
-actor_rollout_ref.ref.entropy_from_logits_with_chunking=True
-
-# 推理引擎显存配置
-# gpu_memory_utilization: 控制vLLM使用的GPU显存比例(0.90 = 90%)
-# enforce_eager=False: 启用CUDA graphs加速推理,但会占用额外显存
-actor_rollout_ref.rollout.gpu_memory_utilization=0.90
-actor_rollout_ref.rollout.enforce_eager=False
-~~~
-
-### NPU调优参考文章
-
-环境变量相关:[环境变量列表-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/apiref/Envvariables/Envir_001.html)
-
-社区性能调优教程:[性能调优流程-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0001.html)
-
+# NPU Qwen3-32B GSPO Optimization Practice
+
+Last updated: 02/26/2026.
+
+本文章对应脚本地址:[qwen3_32b_gspo_npu](https://github.com/volcengine/verl/blob/main/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh)
+
+## 算法适配
+
+GSPO通过将优化颗粒度从**token级**提升到**sequence级**,规避了GRPO会遇到的**方差急剧增大**导致训练不稳定的情况,增加了训练的稳定性,同时该算法也在一定程度上提升了算法的收敛速度。
+
+想要在verl仓库中成功调用GSPO算法,需要进行如下的必要配置
+
+~~~python
+# 核心算法配置
+algorithm.adv_estimator=grpo \ # 使用GRPO优势估计器
+algorithm.use_kl_in_reward=False \ # 不在奖励中添加KL惩罚
+# GSPO策略损失模式
+actor_rollout_ref.actor.policy_loss.loss_mode=gspo \ # 启用GSPO策略损失
+# 极小裁剪范围(GSPO特色)
+actor_rollout_ref.actor.clip_ratio_low=0.0003 \ # 裁剪下界,论文推荐值
+actor_rollout_ref.actor.clip_ratio_high=0.0004 \ # 裁剪上界,论文推荐值
+# KL配置(GSPO不使用KL loss)
+actor_rollout_ref.actor.use_kl_loss=False \ # 禁用KL损失
+actor_rollout_ref.actor.kl_loss_coef=0.0 \ # KL损失系数设为0
+# 序列级损失聚合模式(GSPO核心)
+actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \ # 序列级平均,GSPO论文推荐
+# 批次配置
+actor_rollout_ref.rollout.n=16 \ # 每个prompt生成16个响应(组采样)
+~~~
+
+一般选择入口函数为`verl.trainer.main_ppo`
+
+## 基础环境
+
+当前支持Atlas 800T A3 与 Atlas 900 A3 SuperPoD。完整跑完本次最佳实践需要 4台Atlas 800T A3。关键软件版本可以参考:[Ascend Quickstart](https://github.com/volcengine/verl/blob/main/docs/ascend_tutorial/ascend_quick_start.rst)
+
+### 安装基础环境
+
+| software | version |
+| ------------ | ---------------------------------------------------------- |
+| Python | >= 3.10, <3.12 |
+| CANN | == 8.3.RC1 |
+| torch | == 2.7.1 |
+| torch_npu | == 2.7.1 |
+| verl | main分支 commitId=252d76908b903ad8fb6969eb3a5e5f873c95ea2b |
+| vllm | v0.11.0 |
+| vllm-ascend | v0.11.0-dev |
+| transformers | 4.57.3 |
+
+在本实践中, 我们通过指定 verl 的commit id 以避免引入其他问题
+
+~~~bash
+cd verl
+git checkout 252d76908b903ad8fb6969eb3a5e5f873c95ea2b
+# 指定相应的recipe版本
+git submodule update --init --recursive recipe
+~~~
+
+### 权重获取
+
+从Hugging Face库下载对应的模型权重:[Qwen/Qwen3-32B · Hugging Face](https://huggingface.co/Qwen/Qwen3-32B)
+
+### 数据集准备
+
+~~~bash
+# 下载math-17k数据集
+git clone https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k
+
+# 下载AIME_2024测试数据集
+git clone https://huggingface.co/datasets/Maxwell-Jia/AIME_2024
+~~~
+
+### jemalloc安装
+
+为了确保 Ray 进程能够正常回收内存,需要安装并使能 jemalloc 库进行内存管理。
+
+#### Ubuntu 操作系统
+
+通过操作系统源安装jemalloc(注意: 要求ubuntu版本>=20.04):
+
+```shell
+sudo apt install libjemalloc2
+```
+
+在启动任务前执行如下命令通过环境变量导入jemalloc,需先通过 **find /usr -name libjemalloc.so.2** 确认文件是否存在 :
+
+```shell
+# arm64架构
+export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2
+# x86_64架构
+export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2
+```
+
+#### OpenEuler 操作系统
+
+执行如下命令从操作系统源安装jemalloc
+
+```shell
+yum install jemalloc
+```
+
+如果上述方法无法正常安装,可以通过源码编译安装。前往jemalloc官网下载最新稳定版本,官网地址:https://github.com/jemalloc/jemalloc/releases/
+
+```shell
+tar -xvf jemalloc-{version}.tar.bz2
+cd jemalloc-{version}
+./configure --prefix=/usr/local
+make
+make install
+```
+
+在启动任务前执行如下命令通过环境变量导入jemalloc:
+
+```shell
+#根据实际安装路径设置环境变量,例如安装路径为:/usr/local/lib/libjemalloc.so.2,可通过以下命令来设置环境变量(可通过 find /usr -name libjemalloc.so.2 确认文件是否存在)
+export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2
+```
+
+### 多机任务拉起
+
+针对本实践提供的多机任务,可用下面的脚本拉起
+
+~~~bash
+pkill -9 python
+ray stop --force
+rm -rf /tmp/ray
+
+export RAY_DEDUP_LOGS=0
+export HYDRA_FULL_ERROR=1
+export TASK_QUEUE_ENABLE=1
+export HCCL_EXEC_TIMEOUT=3600
+export HCCL_CONNECT_TIMEOUT=3600
+export HCCL_ASYNC_ERROR_HANDLING=0
+export CPU_AFFINITY_CONF=1
+export VLLM_USE_V1=1
+export VLLM_ATTENTION_BACKEND=XFORMERS
+export VLLM_ASCEND_ENABLE_FLASHCOMM=1
+export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1
+export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1
+export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2
+
+# 修改为当前需要跑的用例路径
+DEFAULT_SH="./run_*.sh"
+echo "Use $DEFAULT_SH"
+
+ulimit -n 32768
+mkdir logs
+
+NNODES=4
+NPUS_PER_NODE=16
+# 修改为对应主节点IP
+MASTER_ADDR="IP FOR MASTER NODE"
+# 修改为当前节点的通信网卡
+SOCKET_IFNAME="Your SOCKET IFNAME"
+export HCCL_SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE"
+export GLOO_SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE"
+# 获取当前IP
+CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}')
+if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then
+ # 主节点启动
+ ray start --head --port 6766 --dashboard-host=$MASTER_ADDR --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}'
+
+ while true; do
+ ray_status_output=$(ray status)
+ npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1)
+ npu_count_int=$(echo "$npu_count" | awk '{print int($1)}')
+ device_count=$((npu_count_int / $NPUS_PER_NODE))
+
+ # 判断device_count 是否与 NNODES 相等
+ if [ "$device_count" -eq "$NNODES" ]; then
+ echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script."
+ ray status
+ bash $DEFAULT_SH
+ break
+ else
+ echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count"
+ sleep 5
+ fi
+ done
+else
+ # 子节点尝试往主节点注册 ray 直到成功
+ while true; do
+ # 尝试连接 ray 集群
+ ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP
+
+ # 检查连接是否成功
+ ray status
+ if [ $? -eq 0 ]; then
+ echo "Successfully connected to the Ray cluster!"
+ break
+ else
+ echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..."
+ sleep 5
+ fi
+ done
+fi
+
+sleep 600
+~~~
+
+DEFAULT_SH:修改为训练所用配置 sh 文件路径。在此案例中修改为 [Qwen3-32B](https://github.com/volcengine/verl/blob/main/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh) 路径。
+
+NNODES 和 NPUS_PER_NODE:修改为使用节点数量和每个节点 NPU 数量。在此案例中分别为4和16。
+
+MASTER_ADDR:修改为对应主节点 IP。即所有节点的 MASTER_ADDR 应该相同。
+
+SOCKET_IFNAME, HCCL_SOCKET_IFNAME, GLOO_SOCKET_IFNAME: 修改为对应通信网卡,通信网卡可以通过以下命令获取:
+
+```
+ifconfig |grep "$(hostname -I |awk '{print $1}'|awk -F '.' '{print $0}')" -B 1|awk -F ':' '{print$1}' | head -1 | tail -1
+```
+
+## 性能调优
+
+优化从训练、推理、调度和其他四个方面入手。
+
+### 训练
+
+#### 动态bsz
+
+~~~bash
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
+~~~
+
+**这个优化点主要调整上面这两个参数,不过需要注意这两个参数调整的太大会导致OOM**
+
+**主要调整**`actor_ppo_max_token_len`,调大了会降低训练的耗时,调整`infer_ppo_max_token_len`没有明显的收益,可以不动
+
+**这两个参数的作用介绍如下:**
+
+**这两个参数用于控制动态批处理(dynamic batch size)模式下每个GPU处理的最大token数量**
+
+- **`actor_ppo_max_token_len`**: Actor模型在PPO更新(前向+反向传播)时每个GPU能处理的最大token数
+- **`infer_ppo_max_token_len`**: 推理阶段(Reference policy和Rollout)计算log概率时每个GPU能处理的最大token数
+
+### 推理
+
+#### ACLgraph+FULL_DECODE_ONLY
+
+推理算子下发方面的优化,平均能有`15%~20%`左右的性能收益。
+
+先看单开**ACLgraph**,如下:
+
+~~~bash
+# 开启ACLgraph+FULL_DECODE_ONLY(注意:当设置此参数为False时,TASK_QUEUE_ENABLE必须设置为1,不然会报错)
+actor_rollout_ref.rollout.enforce_eager=False
+actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes='[8,16,32,64,128]' \
+actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode='FULL_DECODE_ONLY' \
+~~~
+
+`FULL_DECODE_ONLY`开启成功后有如下输出:
+
+
+
+**`cudagraph_capture_sizes`参数设置指南**
+
+cudagraph_capture_sizes设置的值对应的是批大小,这里的批大小不是配置里的DP域对应的那个批次大小,这里是相较于vllm来说的批大小,单位为**token**
+
+默认生成的算法如下,可做参考
+
+
+
+##### 推理后端切换
+
+使用方式:`export VLLM_ATTENTION_BACKEND=XFORMERS`
+
+
+
+注:需要注意某些后端在一些比较老的vllm-ascend版本内并不支持
+
+##### 使能vllm v1版本
+
+使用方式:`export VLLM_USE_V1=1`
+
+可以常开,一般都是正收益。
+
+### 调度
+
+#### AIV
+
+打开方式:设置`export HCCL_OP_EXPANSION_MODE="AIV"`
+
+HCCL_OP_EXPANSION_MODE环境变量用于配置通信算法的编排展开位置,支持如下取值:
+
+- AI_CPU:代表通信算法的编排展开位置在Device侧的AI CPU计算单元。
+- AIV:代表通信算法的编排展开位置在Device侧的Vector Core计算单元。
+- HOST:代表通信算法的编排展开位置为Host侧CPU,Device侧根据硬件型号自动选择相应的调度器。
+- HOST_TS:代表通信算法的编排展开位置为Host侧CPU,Host向Device的Task Scheduler下发任务,Device的Task Scheduler进行任务调度执行。
+
+下面介绍两种展开机制
+
+##### HOST展开
+
+
+
+- 软件栈工作在hostcpu,通信算法展开一个个task
+- 每个task调用runtime接口,下发到device的rtsqueue
+- STARS从rtsqueue上顺序拿取task
+- 根据task类型分别调用SDMA和RDMA引擎。
+ **单算子瓶颈**:hostbound 每个task提交是2~5us,一个通信算子有几百个task,单算子场景不会在device上缓存,下发一个执行一个
+
+##### AICpu机制展开
+
+
+
+- host侧不下发一个个task,把通信算子作为一个个kernel,放在通信算子kernel的队列上去。
+- STARS调度kernel队列流上的kernel,把kernel放到AiCPU上去执行。
+- AICPU调用函数(kernel),用一个线程执行kernel 函数,在函数内把通信task展开,把task放到rtsqueue上,STARS调用。
+- 降低host和aicpu交互,由几百次降低为一次。
+- task的提交在AICPU上提交,做了提交的部分合并。
+
+#### TASK_QUEUE_ENABLE
+
+**使用方式:**`export TASK_QUEUE_ENABLE=2`
+
+TASK_QUEUE_ENABLE,下发优化,图模式设置为1(即开启图模式的时候这个要设置为1),非图模式设置为2
+
+示意图:
+
+
+
+##### 绑核优化
+
+**使用方式:**`export CPU_AFFINITY_CONF=1`
+
+详细设置原理可看:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0059.html
+
+### 其他
+
+以下内容汇总了若干全局环境变量的调优配置。由于这些参数在训练阶段与推理阶段往往都能带来正向收益,且目前尚缺乏足够精细的消融实验来严格区分它们各自对训练或推理的贡献占比,故统一归拢在此,供后续持续监控与进一步拆解分析。
+
+#### 使能jemalloc
+
+使用方式(注意需要先安装jemalloc库):`export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2`
+
+**安装使用教程:**[MindSpeed-RL/docs/install_guide.md · Ascend/MindSpeed-RL - AtomGit | GitCode](https://gitcode.com/Ascend/MindSpeed-RL/blob/master/docs/install_guide.md#高性能内存库-jemalloc-安装)
+
+#### 多流复用
+
+内存方面有优化
+
+使能方式:`export MULTI_STREAM_MEMORY_REUSE=1`
+
+原理介绍:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0040.html
+
+#### VLLM_ASCEND_ENABLE_FLASHCOMM
+
+使用方式:`export VLLM_ASCEND_ENABLE_FLASHCOMM=1`
+
+启用昇腾 NPU 特有的FLASHCOMM高速通信优化技术
+
+地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html
+
+#### VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
+
+使用方式:`export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1`
+
+启用昇腾 NPU针对大模型推理的稠密计算优化
+
+地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html
+
+#### VLLM_ASCEND_ENABLE_PREFETCH_MLP
+
+使用方式:`export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1`
+
+启用 MLP 层的权重预取机制
+
+
+
+### verl框架参数设置
+
+主要是内存方面的一些设置开关(注意,这个里面的优化都或多或少会导致吞吐量有一定程度的劣化)
+
+~~~bash
+# 梯度检查点 (Gradient Checkpointing)
+# 作用: 通过重新计算激活值来节省显存,以计算换内存。在前向传播时不保存中间激活值,反向传播时重新计算,可以显著降低显存占用,允许使用更大的batch size。
+actor_rollout_ref.model.enable_gradient_checkpointing=True
+
+# 参数卸载 (Parameter Offload)
+# 作用: 将模型参数卸载到CPU内存,训练时再加载回GPU。
+actor_rollout_ref.actor.fsdp_config.param_offload=${offload} # True
+actor_rollout_ref.ref.fsdp_config.param_offload=${offload} # True
+
+# 优化器状态卸载 (Optimizer Offload)
+# 作用: 将优化器状态(如Adam的动量)卸载到CPU。优化器状态通常占用大量显存(对于Adam,每个参数需要额外8字节),卸载可以节省显存。
+actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} # True
+
+# 释放推理引擎缓存 (Free Cache Engine)
+# 作用: 在训练阶段释放推理引擎的KV cache和权重。这是3D-HybridEngine的核心优化,允许在同一GPU上交替进行推理和训练,显著降低显存需求。
+actor_rollout_ref.rollout.free_cache_engine=True
+
+# 熵计算优化
+# entropy_checkpointing: 在训练时对熵计算启用重计算,降低显存峰值
+# entropy_from_logits_with_chunking: 分块处理logits张量(如2048 tokens一组),避免一次性加载整个[bsz*seq_len, vocab]张量
+actor_rollout_ref.actor.entropy_checkpointing=True
+actor_rollout_ref.ref.entropy_checkpointing=True
+actor_rollout_ref.actor.entropy_from_logits_with_chunking=True
+actor_rollout_ref.ref.entropy_from_logits_with_chunking=True
+
+# 推理引擎显存配置
+# gpu_memory_utilization: 控制vLLM使用的GPU显存比例(0.90 = 90%)
+# enforce_eager=False: 启用CUDA graphs加速推理,但会占用额外显存
+actor_rollout_ref.rollout.gpu_memory_utilization=0.90
+actor_rollout_ref.rollout.enforce_eager=False
+~~~
+
+## NPU调优参考文章
+
+环境变量相关:[环境变量列表-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/apiref/Envvariables/Envir_001.html)
+
+社区性能调优教程:[性能调优流程-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0001.html)
+
+
+
diff --git a/docs/examples/config.rst b/docs/examples/config.rst
index 9909dd67581..b1053903510 100644
--- a/docs/examples/config.rst
+++ b/docs/examples/config.rst
@@ -130,12 +130,10 @@ Actor/Rollout/Reference Policy
use_kl_loss: False # True for GRPO
# Rollout Correction (corrects distribution mismatch between rollout and training)
rollout_correction:
- rollout_is: token # IS weights: token/sequence/null
+ rollout_is: token # IS weights
rollout_is_threshold: 2.0 # Upper threshold for IS weights
- rollout_rs: null # Rejection sampling: token/sequence/geometric/null
+ rollout_rs: null # Rejection sampling
rollout_rs_threshold: null # RS upper threshold
- rollout_rs_threshold_lower: null # RS lower threshold
- rollout_token_veto_threshold: null # Per-token veto (null to disable)
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
@@ -534,12 +532,10 @@ Algorithm
target_kl: 0.1
# Rollout Correction
rollout_correction:
- rollout_is: null # IS weights: token/sequence/null
+ rollout_is: null # IS weights
rollout_is_threshold: 2.0 # Upper threshold for IS weights
- rollout_rs: null # Rejection sampling: token/sequence/geometric/null
+ rollout_rs: null # Rejection sampling
rollout_rs_threshold: null # RS upper threshold
- rollout_rs_threshold_lower: null # RS lower threshold
- rollout_token_veto_threshold: null # Per-token veto (null to disable)
- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
@@ -557,12 +553,10 @@ Algorithm
- ``rollout_correction``: Rollout Correction configuration (nested dict). Set to ``null`` to disable.
When enabled, contains:
- - ``rollout_is``: IS weights aggregation level: ``token``, ``sequence``, or ``null`` to disable IS weights.
+ - ``rollout_is``: IS weights aggregation level, ``null`` to disable IS weights.
- ``rollout_is_threshold``: Upper threshold for IS weights (e.g., 2.0).
- - ``rollout_rs``: Rejection sampling mode: ``token``, ``sequence``, ``geometric``, or ``null`` to disable RS.
+ - ``rollout_rs``: Rejection sampling mode, ``null`` to disable RS.
- ``rollout_rs_threshold``: RS upper threshold.
- - ``rollout_rs_threshold_lower``: RS lower threshold (null = auto-reciprocal).
- - ``rollout_token_veto_threshold``: Per-token veto threshold for catastrophic outliers (null = disabled).
Note: Rollout Correction requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
diff --git a/docs/examples/gsm8k_example.rst b/docs/examples/gsm8k_example.rst
index bc56497be64..1f5bdde7a22 100644
--- a/docs/examples/gsm8k_example.rst
+++ b/docs/examples/gsm8k_example.rst
@@ -75,7 +75,7 @@ model.
---------------------------------
We provide a SFT Trainer using PyTorch FSDP in
-`fsdp_sft_trainer.py
+
+
+
+
+
+This mode requires the P2P feature of checkpoint_engine. Please ensure you have installed it via pip install 'checkpoint-engine[p2p]' and that your version is 0.4.0 or higher.
+
+In addition, during the installation of checkpoint-engine[p2p], the transfer engine will be installed. However, this library has no prebuilt packages for Ascend devices and must be compiled from source. For detailed compilation instructions, see: [transfer-engine: ascend direct](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/design/transfer-engine/ascend_direct_transport.md)
### Benchmark
1. benchmark setup
- model: Qwen/Qwen3-30B-A3B-Base
-- trainer: fsdp world_size=2
+- trainer: fsdp world_size=2 (since Ascend 910C has 64GB of HBM, we set world_size=4)
- rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang)
```bash
-python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
+pytest tests/checkpoint_engine/test_correctness_on_gpu.py
+pytest tests/checkpoint_engine/test_correctness_on_npu.py
+pytest tests/checkpoint_engine/test_special_server_adapter.py
```
2. benchmark result
@@ -36,4 +47,5 @@ python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
|----|----|----|----|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25|
-|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3|
\ No newline at end of file
+|2*16 Ascend 910C, inner supernode| HCCL | ~11 | 5.3|
+|2*16 Ascend 910C, inner supernode| kimi_ckpt_engine | offload: 7 update: 3.5 | 16.5|
diff --git a/verl/checkpoint_engine/__init__.py b/verl/checkpoint_engine/__init__.py
index 4409369e8e8..e0c827aec7d 100644
--- a/verl/checkpoint_engine/__init__.py
+++ b/verl/checkpoint_engine/__init__.py
@@ -44,10 +44,16 @@
except ImportError:
HCCLCheckpointEngine = None
-
try:
from .nixl_checkpoint_engine import NIXLCheckpointEngine
__all__ += ["NIXLCheckpointEngine"]
except ImportError:
NIXLCheckpointEngine = None
+
+try:
+ from .kimi_checkpoint_engine import KIMICheckpointEngine
+
+ __all__ += ["KIMICheckpointEngine"]
+except ImportError:
+ KIMICheckpointEngine = None
diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py
index f3a89c67d95..f722c1f4948 100644
--- a/verl/checkpoint_engine/base.py
+++ b/verl/checkpoint_engine/base.py
@@ -23,7 +23,7 @@
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.ray_utils import auto_await
-from verl.workers.config import HFModelConfig, RolloutConfig
+from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class
@@ -255,20 +255,32 @@ def __init__(
rollout_config: RolloutConfig,
model_config: HFModelConfig,
server_adapter: BaseRollout = None,
+ *args,
+ **kwargs,
) -> None:
+ super().__init__()
self.rollout_config = rollout_config
self.model_config = model_config
+ self.server_adapter: BaseRollout = server_adapter
+ backend = self.rollout_config.checkpoint_engine.backend
+ bucket_size = self.rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20
+ engine_kwargs = self.rollout_config.checkpoint_engine.engine_kwargs.get(backend, {})
+ self.checkpoint_engine: CheckpointEngine = CheckpointEngineRegistry.new(
+ backend, bucket_size=bucket_size, **engine_kwargs
+ )
+ self.extra_rollout_args = args
+ self.extra_rollout_kwargs = kwargs
+ if self.server_adapter is None:
+ self.server_adapter = get_rollout_class(self.rollout_config.name, self.rollout_config.mode)(
+ *self.extra_rollout_args,
+ config=self.rollout_config,
+ model_config=self.model_config,
+ device_mesh=None,
+ **self.extra_rollout_kwargs,
+ )
# sglang and trt-llm need device_mesh for internal communication
initialize_global_process_group_ray(timeout_second=None, backend="cpu:gloo")
- self.server_adapter: BaseRollout = server_adapter or get_rollout_class(
- rollout_config.name, rollout_config.mode
- )(config=rollout_config, model_config=model_config, device_mesh=None)
-
- backend = rollout_config.checkpoint_engine.backend
- bucket_size = rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20
- engine_kwargs = rollout_config.checkpoint_engine.engine_kwargs.get(backend, {})
- self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs)
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
async def update_weights(self):
@@ -279,6 +291,16 @@ async def update_weights(self):
def execute_checkpoint_engine(self, method: str, *args, **kwargs):
return getattr(self.checkpoint_engine, method)(*args, **kwargs)
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def get_replica_rank(self) -> int:
+ """Get replica rank from the underlying rollout server adapter."""
+ return self.server_adapter.replica_rank
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def is_leader_rank(self) -> bool:
+ """Get leader rank flag from the underlying rollout server adapter."""
+ return self.server_adapter.is_leader_rank
+
_worker_cls = ray.remote(CheckpointEngineWorker)
@@ -307,19 +329,20 @@ class CheckpointEngineManager:
```
Args:
- backend: The checkpoint engine backend.
+ config: The checkpoint engine config.
trainer: The trainer worker group.
replicas: The list of rollout replicas.
"""
def __init__(
self,
- backend: str,
+ config: CheckpointEngineConfig,
trainer: RayWorkerGroup,
replicas: list[RolloutReplica],
) -> None:
- self.backend = backend
- self.backend_cls = CheckpointEngineRegistry.get(backend)
+ self.config = config
+ self.backend = config.backend
+ self.backend_cls = CheckpointEngineRegistry.get(config.backend)
self.trainer = trainer
self.replicas = replicas
diff --git a/verl/checkpoint_engine/hccl_checkpoint_engine.py b/verl/checkpoint_engine/hccl_checkpoint_engine.py
index 18366d6b2a2..c4839999ddf 100644
--- a/verl/checkpoint_engine/hccl_checkpoint_engine.py
+++ b/verl/checkpoint_engine/hccl_checkpoint_engine.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import asyncio
import logging
import os
import time
@@ -67,8 +66,7 @@ def __init__(
self.socket = socket
self.topic = topic
- loop = asyncio.get_running_loop()
- self._task = loop.run_in_executor(None, self._run)
+ self._run()
def _run(self):
# broadcast tensor meta via zeromq PUB/SUB
@@ -88,7 +86,6 @@ async def wait_for_complete(self) -> dict[str, TensorMeta]:
Returns:
dict[str, TensorMeta]: The bucket meta after broadcast.
"""
- await self._task
return self.metadata
@@ -148,6 +145,7 @@ def finalize(self):
self.send_buf = None
self.recv_buf = None
+ torch.npu.empty_cache()
@classmethod
def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]):
@@ -165,7 +163,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada
def _start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
- self.zmq_port, self.listen_sock = get_free_port(self.ip)
+ self.zmq_port, _ = get_free_port(self.ip)
context = zmq.Context()
self.socket = context.socket(zmq.PUB)
diff --git a/verl/checkpoint_engine/kimi_checkpoint_engine.py b/verl/checkpoint_engine/kimi_checkpoint_engine.py
new file mode 100644
index 00000000000..f042c3489d8
--- /dev/null
+++ b/verl/checkpoint_engine/kimi_checkpoint_engine.py
@@ -0,0 +1,384 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import concurrent.futures
+import logging
+import os
+import time
+import types
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import AsyncGenerator, Generator
+
+import checkpoint_engine.distributed as dist
+import ray
+import torch
+from checkpoint_engine.ps import H2DBucket, ParameterMeta, ParameterServer, _gen_h2d_buckets, _to_named_tensor
+
+from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry
+from verl.utils.device import get_nccl_backend, get_torch_device
+from verl.utils.net_utils import get_free_port
+
+logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+def ckpt_get_named_tensor_buckets(
+    iterable: Generator[tuple[str, torch.Tensor], None, None],
+    bucket_bytes: int,
+    world_size: int,
+    rank_id: int,
+    rollout_dtype: torch.dtype = torch.bfloat16,
+) -> Generator[dict[str, torch.Tensor], None, None]:
+    if bucket_bytes <= 0:
+        raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}")
+
+    current_bucket = {}
+    current_size = 0
+    for tensor_idx, (name, tensor) in enumerate(iterable):
+        tensor = tensor.to(rollout_dtype)
+        if tensor_idx % world_size == rank_id:
+            tensor_size = tensor.element_size() * tensor.numel()
+            if current_size + tensor_size > bucket_bytes:
+                if current_bucket:
+                    yield current_bucket
+                    current_bucket = {}
+                    current_size = 0
+
+            current_bucket[name] = tensor
+            current_size += tensor_size
+
+    if current_bucket:
+        yield current_bucket
+
+
+async def receive_tensor(
+    self,
+    checkpoint_name: str,
+    ranks_group: int,
+    ranks: list[int] | None = None,
+    bucket_size: int = 2 << 30,
+    disable_h2d_buffer: bool = False,
+) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+    assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
+    assert dist.is_initialized(), "process group is not initialized"
+    assert self._p2p_store is not None, "p2p store is not initialized"
+    assert ranks, "ranks should be set"
+
+    # Barrier first so every rank enters together, avoiding a subsequent device OOM.
+    dist.barrier(group=ranks_group)
+    buckets = _gen_h2d_buckets(
+        self._current_global_parameter_metas,
+        bucket_size,
+        self._local_rdma_devices,
+        self._remote_rdma_devices,
+        ranks,
+    )
+    h2d_buffer: torch.Tensor | None = (
+        None
+        if disable_h2d_buffer
+        else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
+    )
+    # The p2p store must register h2d_buffer so that other ranks can read it.
+    if ranks:
+        h2d_buffer_name = "__h2d_buffer__"
+        if h2d_buffer is not None and self._p2p_store is not None:
+            self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
+    receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
+    for receiver_rank, owner_rank, bucket in buckets:
+        if receiver_rank != self._rank:
+            continue
+        receiver_rank_buckets.append((owner_rank, bucket))
+    buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type)
+    buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
+
+    max_len = 0
+    for receiver_rank, _, bucket in buckets:
+        buckets_by_receiver_rank[receiver_rank].append(bucket)
+        if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
+            max_len = len(buckets_by_receiver_rank[receiver_rank])
+    gidx = 0
+    metadata: list[ParameterMeta]
+    try:
+        for i in range(max_len):
+            if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+                self._copy_to_buffer(
+                    checkpoint_name,
+                    receiver_rank_buckets[i][1],
+                    h2d_buffer,
+                    receiver_rank_buckets[i][0] if ranks else None,
+                )
+            for receiver_rank, _buckets in buckets_by_receiver_rank.items():
+                if i >= len(_buckets):
+                    continue
+                bucket = _buckets[i]
+                start = gidx % 2 * bucket_size
+                buffer_b: torch.Tensor = buffer[start : start + bucket.size]
+                if receiver_rank == self._rank:
+                    if disable_h2d_buffer:
+                        self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
+                    else:
+                        buffer_b.data.copy_(h2d_buffer[: bucket.size])
+                broadcast_op = BroadcastOperation(
+                    rank=receiver_rank,
+                    ranks_group=ranks_group,
+                    bucket=buffer_b,
+                    metadata=bucket.items,
+                )
+                if gidx == 0:
+                    metadata = await broadcast_op.wait_for_complete()
+                    gidx += 1
+                    continue
+                meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size)
+                for item in meta_list:
+                    shape = item["shape"]
+                    if isinstance(shape, list | tuple):
+                        shape = torch.Size(shape)
+                    assert isinstance(shape, torch.Size)
+                    dtype, offset = item["dtype"], item["offset"]
+                    size = dtype.itemsize * shape.numel()
+                    tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
+                    yield item["name"], tensor
+                metadata = await broadcast_op.wait_for_complete()
+                self.device_manager.device_module.synchronize()
+                gidx += 1
+
+        meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size)
+        for item in meta_list:
+            shape = item["shape"]
+            if isinstance(shape, list | tuple):
+                shape = torch.Size(shape)
+            assert isinstance(shape, torch.Size)
+            dtype, offset = item["dtype"], item["offset"]
+            size = dtype.itemsize * shape.numel()
+            tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
+            yield item["name"], tensor
+
+    finally:
+        dist.barrier(group=ranks_group)
+        if ranks and h2d_buffer is not None:
+            self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+        self.device_manager.device_module.empty_cache()
+
+
+@dataclass
+class MasterMetadata:
+ zmq_ip: str
+ zmq_port: int
+ dist_ip: str
+ dist_port: int
+
+
+class BroadcastOperation:
+    """Async broadcast operation run in a background executor thread.
+
+    Args:
+        rank (int): The rank of the current process.
+        ranks_group (int): Handle of the process group used for the broadcast.
+        bucket (torch.Tensor): The tensor to broadcast.
+        metadata (list[ParameterMeta]): The metadata of the tensor.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        ranks_group: int,
+        bucket: torch.Tensor,
+        metadata: list[ParameterMeta],
+    ) -> None:
+        self.rank = rank
+        self.ranks_group = ranks_group
+        self.bucket = bucket
+        self.metadata = metadata
+
+        loop = asyncio.get_running_loop()
+        self._task = loop.run_in_executor(None, self._run)
+
+    def _run(self):
+        # Blocking collective; executes in the executor thread started in __init__.
+        dist.broadcast(self.bucket, src=self.rank, group=self.ranks_group)
+
+    async def wait_for_complete(self) -> list[ParameterMeta]:
+        """Wait for the broadcast operation to complete.
+
+        Returns:
+            list[ParameterMeta]: The bucket meta after broadcast.
+        """
+        await self._task
+        return self.metadata
+
+
+@CheckpointEngineRegistry.register("kimi_ckpt_engine")
+class KIMICheckpointEngine(CheckpointEngine):
+    """NCCL checkpoint engine with collective communication.
+
+    Trainer ranks offload weights to CPU and register them with a
+    ``ParameterServer``; rollout ranks then receive them through the
+    rollout process group (see ``receive_weights``).
+
+    Args:
+        bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
+            two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size.
+        rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False.
+        is_master (bool): Whether the current process is the master process. Defaults to False.
+        rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16.
+    """
+
+    def __init__(
+        self,
+        bucket_size: int,
+        rebuild_group: bool = False,
+        is_master: bool = False,
+        rollout_dtype: torch.dtype = torch.bfloat16,
+    ) -> None:
+        self.bucket_size = bucket_size
+        self.rebuild_group = rebuild_group
+        self.rollout_dtype = rollout_dtype
+        self.is_master = is_master
+        # Set to True once init_process_group has built the ParameterServer/groups.
+        self.initialized = False
+        # Logical checkpoint name used for the register/gather/unregister round trip.
+        self.checkpoint_name = "kimi_checkpoint_engine"
+
+    def prepare(self) -> MasterMetadata:
+        """Pick a rendezvous address on the master rank.
+
+        Returns:
+            MasterMetadata: ``dist_ip``/``dist_port`` for the torch.distributed
+            rendezvous when ``is_master`` is True (the zmq fields are unused by
+            this engine and left as None); None on non-master ranks.
+        """
+        if self.is_master:
+            # strip("[]") normalizes bracketed IPv6 addresses returned by Ray.
+            self.ip = ray.util.get_node_ip_address().strip("[]")
+            self.listen_port, _ = get_free_port(self.ip)
+
+        return (
+            MasterMetadata(zmq_ip=None, zmq_port=None, dist_ip=self.ip, dist_port=self.listen_port)
+            if self.is_master
+            else None
+        )
+
+    def finalize(self):
+        """Destroy the ckpt engine process group if rebuild_group is True."""
+        if self.rebuild_group:
+            dist.destroy_process_group()
+            self.rank = None
+            self.world_size = None
+            self.initialized = False
+
+    @classmethod
+    def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]):
+        """Build per-rank kwargs for ``init_process_group`` on every worker.
+
+        Trainer ranks occupy [0, trainer_world_size); rollout ranks follow at
+        [trainer_world_size, trainer_world_size + rollout_world_size).
+        ``metadata[0]`` is the master's ``MasterMetadata`` from ``prepare``.
+
+        Returns:
+            tuple[dict, dict]: (trainer_kwargs, rollout_kwargs), each mapping an
+            argument name to a per-rank list of values.
+        """
+        trainer_kwargs = {
+            "method": ["init_process_group"] * trainer_world_size,
+            "rank": list(range(0, trainer_world_size)),
+            "trainer_world_size": [trainer_world_size] * trainer_world_size,
+            "rollout_world_size": [rollout_world_size] * trainer_world_size,
+            "master_metadata": [metadata[0]] * trainer_world_size,
+        }
+        rollout_kwargs = {
+            "method": ["init_process_group"] * rollout_world_size,
+            "rank": list(range(trainer_world_size, trainer_world_size + rollout_world_size)),
+            "trainer_world_size": [trainer_world_size] * rollout_world_size,
+            "rollout_world_size": [rollout_world_size] * rollout_world_size,
+            "master_metadata": [metadata[0]] * rollout_world_size,
+        }
+        return trainer_kwargs, rollout_kwargs
+
+    def init_process_group(
+        self,
+        rank: int,
+        trainer_world_size: int,
+        rollout_world_size: int,
+        master_metadata: MasterMetadata,
+    ):
+        """Initialize the ckpt engine process group.
+
+        Idempotent: the ParameterServer and rollout sub-group are built only on
+        the first call (guarded by ``self.initialized``).
+
+        Args:
+            rank (int): Global rank of this process in the combined
+                trainer+rollout group (trainer ranks come first).
+            trainer_world_size (int): Number of trainer processes.
+            rollout_world_size (int): Number of rollout processes.
+            master_metadata (MasterMetadata): Rendezvous address chosen by the
+                master rank in ``prepare``.
+        """
+        self.rank = rank
+        self.trainer_world_size = trainer_world_size
+        self.rollout_world_size = rollout_world_size
+        self.world_size = trainer_world_size + rollout_world_size
+
+        if not self.initialized:
+            self.parameter_server = ParameterServer(
+                rank=rank,
+                world_size=self.world_size,
+                auto_pg=False,
+                master_addr=master_metadata.dist_ip,
+                master_port=master_metadata.dist_port,
+            )
+            # Monkey-patch the module-level receive_tensor onto this instance so
+            # receive_weights can stream tensors with this engine's semantics.
+            self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
+
+            # Select the vllm-wrapped NCCL-compatible backend before rendezvous.
+            dist.use_backend(f"vllm_{get_nccl_backend()}")
+            self.parameter_server.init_process_group()
+
+            # Sub-group containing only rollout ranks (they trail the trainer ranks).
+            self.rollout_ranks = list(range(self.trainer_world_size, self.world_size))
+            self.rollout_group = dist.new_group(self.rollout_ranks)
+            self.initialized = True
+
+    @torch.no_grad()
+    async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
+        """Send the weights of the model.
+
+        Offloads bucketed GPU weights to CPU, registers them as a checkpoint
+        with the ParameterServer, then gathers metas and waits on a barrier so
+        receivers can fetch before the checkpoint is unregistered.
+
+        Args:
+            weights: A generator that yields the name of the weight tensor and the tensor itself.
+        """
+
+        def offload_cpu(name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
+            # non_blocking=True overlaps D2H copies; the synchronize below fences them.
+            return name, tensor.to("cpu", non_blocking=True)
+
+        start_time = time.time()
+        named_tensors = {}
+        for named_tensors_gpu in ckpt_get_named_tensor_buckets(
+            weights, self.bucket_size, self.trainer_world_size, self.rank, self.rollout_dtype
+        ):
+            # Parallelize per-tensor D2H copies for each bucket with a thread pool.
+            with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+                futures = [
+                    executor.submit(
+                        offload_cpu,
+                        name,
+                        tensor,
+                    )
+                    for name, tensor in named_tensors_gpu.items()
+                ]
+                for future in concurrent.futures.as_completed(futures):
+                    name, tensor_cpu = future.result()
+                    named_tensors[name] = tensor_cpu
+
+        # Fence the async D2H copies before handing tensors to the parameter server.
+        get_torch_device().synchronize()
+
+        self.parameter_server.register_checkpoint(self.checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}  # drop refs so the CPU copies can be reclaimed later
+        get_torch_device().empty_cache()
+        logger.info(f"Rank {self.rank} offload and register, time cost: {time.time() - start_time:.2f}s")
+
+        self.parameter_server.gather_metas(self.checkpoint_name)
+        # Barrier: keep the checkpoint registered until every rank has consumed it.
+        dist.barrier()
+        self.parameter_server.unregister_checkpoint(self.checkpoint_name)
+        logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s")
+
+    @torch.no_grad()
+    async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+        """Receive the weights of the model.
+
+        Streams tensors from the (monkey-patched) ``receive_tensor`` over the
+        rollout sub-group, then joins the global barrier matching the one in
+        ``send_weights``.
+
+        Yields:
+            A tuple of the name of the weight tensor and the tensor itself.
+        """
+        self.parameter_server.gather_metas(self.checkpoint_name)
+
+        start_time = time.time()
+        total_bytes, total_params = 0, 0
+        async for name, tensor in self.parameter_server.receive_tensor(
+            self.checkpoint_name, self.rollout_group, self.rollout_ranks, self.bucket_size
+        ):
+            total_bytes += tensor.element_size() * tensor.nelement()
+            total_params += 1
+            yield name, tensor
+        dist.barrier()
+        time_cost = time.time() - start_time
+        # NOTE(review): divisor is 1024**3, so this is GiB/s though logged as GB/s.
+        bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
+        logger.info(
+            f"Rank {self.rank} receive weights done, total_params: {total_params}, "
+            f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s"
+        )
diff --git a/verl/checkpoint_engine/nccl_checkpoint_engine.py b/verl/checkpoint_engine/nccl_checkpoint_engine.py
index e3f8d99447a..279733900d6 100644
--- a/verl/checkpoint_engine/nccl_checkpoint_engine.py
+++ b/verl/checkpoint_engine/nccl_checkpoint_engine.py
@@ -148,6 +148,8 @@ def finalize(self):
self.send_buf = None
self.recv_buf = None
+ torch.cuda.empty_cache()
+
@classmethod
def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]):
trainer_kwargs = {
@@ -164,7 +166,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada
def _start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
- self.listen_port, self.listen_sock = get_free_port(self.ip)
+ self.listen_port, _ = get_free_port(self.ip)
context = zmq.Context()
self.socket = context.socket(zmq.PUB)
diff --git a/verl/checkpoint_engine/nixl_checkpoint_engine.py b/verl/checkpoint_engine/nixl_checkpoint_engine.py
index 5fd8c44f509..fbdefc5b230 100644
--- a/verl/checkpoint_engine/nixl_checkpoint_engine.py
+++ b/verl/checkpoint_engine/nixl_checkpoint_engine.py
@@ -82,7 +82,7 @@ def get_agent_metadata(self) -> NixlAgentMetadata:
def start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
- self.listen_port, self.listen_sock = get_free_port(self.ip)
+ self.listen_port, _ = get_free_port(self.ip)
context = zmq.asyncio.Context()
self.socket = context.socket(zmq.PULL)
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 228d2248b7e..6dd3871c34e 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -35,24 +35,31 @@
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.protocol import DataProto
from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
-from verl.utils import hf_processor, hf_tokenizer
from verl.utils.chat_template import initialize_system_prompt
+from verl.utils.config import omega_conf_to_dataclass
from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class
-from verl.utils.fs import copy_to_local
from verl.utils.model import compute_position_id_with_mask
-from verl.utils.ray_utils import get_event_loop
+from verl.utils.ray_utils import auto_await, get_event_loop
from verl.utils.rollout_trace import (
RolloutTraceConfig,
rollout_trace_attr,
rollout_trace_op,
)
-from verl.utils.transferqueue_utils import tqbridge
+from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+def _get_rollout_and_model_config(config: DictConfig) -> tuple[DictConfig, DictConfig]:
+ # TODO: backward compatibility, remove this once we switch to new trainer.
+ if config.get("actor_rollout_ref"):
+ return config.actor_rollout_ref.rollout, config.actor_rollout_ref.model
+ else:
+ return config.rollout, config.model
+
+
class AsyncLLMServerManager:
"""
A class to manage multiple OpenAI compatible LLM servers. This class provides
@@ -64,7 +71,7 @@ def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandl
"""Initialize the AsyncLLMServerManager.
Args:
- config (DictConfig): YAML config.
+ config (DictConfig): whole config for main entrypoint.
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000.
"""
@@ -190,7 +197,16 @@ def __init__(self, config: DictConfig):
class AgentLoopBase(ABC):
"""An agent loop takes an input message, chat with OpenAI compatible LLM server and interact with various
- environments."""
+ environments.
+
+ Args:
+ trainer_config (DictConfig): whole config for main entrypoint.
+ server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.
+ tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
+ processor (AutoProcessor): Processor for process messages.
+ dataset_cls (type[Dataset]): Dataset class for creating dataset, Defaults to RLHFDataset.
+ data_config (DictConfigWrap): Dataset config.
+ """
def __init__(
self,
@@ -199,26 +215,17 @@ def __init__(
tokenizer: AutoTokenizer,
processor: AutoProcessor,
dataset_cls: type[RLHFDataset],
- dataset_config: DictConfigWrap,
+ data_config: DictConfigWrap,
**kwargs,
):
- """Initialize agent loop, each sample will have its own loop instance.
-
- Args:
- trainer_config (DictConfigWrap): trainer config.
- server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.
- tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
- processor (AutoProcessor): Processor for process messages.
- dataset_cls (type[Dataset]): Dataset class for creating dataset, Defaults to RLHFDataset.
- dataset_config (DictConfigWrap): Dataset config.
- """
self.config = trainer_config.config
+ self.rollout_config, _ = _get_rollout_and_model_config(self.config)
self.server_manager = server_manager
self.tokenizer = tokenizer
self.processor = processor
self.dataset_cls = dataset_cls
- self.dataset_config = dataset_config.config
- self.apply_chat_template_kwargs = self.dataset_config.get("apply_chat_template_kwargs", {})
+ self.data_config = data_config.config
+ self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {})
self.system_prompt = initialize_system_prompt(self.tokenizer, **self.apply_chat_template_kwargs)
self.loop = get_event_loop()
@@ -234,7 +241,7 @@ async def process_vision_info(self, messages: list[dict]) -> dict:
multi_modal_data = {}
if self.processor is not None:
images, videos = await self.dataset_cls.process_vision_info(
- messages, image_patch_size=self.processor.image_processor.patch_size, config=self.dataset_config
+ messages, image_patch_size=self.processor.image_processor.patch_size, config=self.data_config
)
if images is not None:
multi_modal_data["images"] = images
@@ -342,7 +349,13 @@ def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
class AgentLoopWorker:
- """Agent loop worker takes a batch of messages and run each message in an agent loop."""
+ """Agent loop worker takes a batch of messages and run each message in an agent loop.
+
+ Args:
+ config (DictConfig): whole config for main entrypoint.
+ server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
+ reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
+ """
def __init__(
self,
@@ -350,13 +363,10 @@ def __init__(
server_handles: list[ray.actor.ActorHandle],
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
- """Initialize agent loop manager.
- Args:
- config (DictConfig): YAML config.
- server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
- reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
- """
self.config = config
+ rollout_config, model_config = _get_rollout_and_model_config(config)
+ self.rollout_config: RolloutConfig = omega_conf_to_dataclass(rollout_config)
+ self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config)
# for recipe to change
if not hasattr(self, "server_manager"):
@@ -365,33 +375,29 @@ def __init__(
self.dataset_cls = get_dataset_class(config.data)
self.reward_loop_worker_handles = reward_loop_worker_handles
- model_path = config.actor_rollout_ref.model.path
- self.model_name = "/".join(model_path.split("/")[-2:])
- local_path = copy_to_local(config.actor_rollout_ref.model.path)
- self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
- self.processor = hf_processor(local_path, trust_remote_code=True)
+ self.tokenizer = self.model_config.tokenizer
+ self.processor = self.model_config.processor
- agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path
+ agent_loop_config_path = self.rollout_config.agent.agent_loop_config_path
if agent_loop_config_path:
resolved_path = resolve_config_path(agent_loop_config_path)
agent_loop_configs = OmegaConf.load(resolved_path)
for agent_loop_config in agent_loop_configs:
_agent_loop_registry[agent_loop_config.name] = agent_loop_config
- if self.config.actor_rollout_ref.model.get("custom_chat_template", None) is not None:
- if self.processor is not None:
- self.processor.chat_template = self.config.actor_rollout_ref.model.custom_chat_template
- self.tokenizer.chat_template = self.config.actor_rollout_ref.model.custom_chat_template
+ if self.model_config.get("custom_chat_template", None) is not None:
+ if self.model_config.processor is not None:
+ self.model_config.processor.chat_template = self.model_config.custom_chat_template
+ self.model_config.tokenizer.chat_template = self.model_config.custom_chat_template
- trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
+ trace_config = self.rollout_config.trace
RolloutTraceConfig.init(
- self.config.trainer.project_name,
- self.config.trainer.experiment_name,
+ self.rollout_config.trace.project_name,
+ self.rollout_config.trace.experiment_name,
trace_config.get("backend"),
trace_config.get("token2text", False),
trace_config.get("max_samples_per_step_per_worker", None),
)
- @tqbridge()
async def generate_sequences(self, batch: DataProto) -> DataProto:
"""Generate sequences from agent loop.
@@ -413,7 +419,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|
response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|
"""
- config = self.config.actor_rollout_ref.rollout
+ config = self.rollout_config
sampling_params = dict(
temperature=config.temperature,
top_p=config.top_p,
@@ -502,7 +508,7 @@ async def _run_agent_loop(
tokenizer=self.tokenizer,
processor=self.processor,
dataset_cls=self.dataset_cls,
- dataset_config=DictConfigWrap(self.config.data),
+ data_config=DictConfigWrap(self.config.data),
)
output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)
return await self._agent_loop_postprocess(output, **kwargs)
@@ -536,7 +542,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
prompt_output = self.tokenizer.pad(
{"input_ids": output.prompt_ids},
padding="max_length",
- max_length=self.config.actor_rollout_ref.rollout.prompt_length,
+ max_length=self.rollout_config.prompt_length,
return_tensors="pt",
return_attention_mask=True,
)
@@ -548,7 +554,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
response_output = self.tokenizer.pad(
{"input_ids": output.response_ids},
padding="max_length",
- max_length=self.config.actor_rollout_ref.rollout.response_length,
+ max_length=self.rollout_config.response_length,
return_tensors="pt",
return_attention_mask=True,
)
@@ -559,7 +565,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
response_mask_output = self.tokenizer.pad(
{"input_ids": output.response_mask},
padding="max_length",
- max_length=self.config.actor_rollout_ref.rollout.response_length,
+ max_length=self.rollout_config.response_length,
return_tensors="pt",
return_attention_mask=False,
)
@@ -568,7 +574,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
response_logprobs = None
if output.response_logprobs is not None:
- pad_size = self.config.actor_rollout_ref.rollout.response_length - len(output.response_logprobs)
+ pad_size = self.rollout_config.response_length - len(output.response_logprobs)
response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0)
response_mask = response_mask_output["input_ids"] * response_output["attention_mask"]
@@ -777,8 +783,17 @@ def _postprocess(
metrics = [input.metrics.model_dump() for input in inputs]
# Collect extra fields from all inputs and convert them to np.ndarray
+ # Keep a stable set of keys so downstream batch concat stays consistent across agent loops.
extra_fields = {}
- all_keys = set(key for input_item in inputs for key in input_item.extra_fields)
+ default_extra_keys = {
+ "turn_scores",
+ "tool_rewards",
+ "is_cancel",
+ "param_version_start",
+ "param_version_end",
+ "extras",
+ }
+ all_keys = set(key for input_item in inputs for key in input_item.extra_fields) | default_extra_keys
for key in all_keys:
temp_arr = np.empty(len(inputs), dtype=object)
temp_arr[:] = [input.extra_fields.get(key) for input in inputs]
@@ -799,20 +814,6 @@ def _postprocess(
meta_info=meta_info,
)
- def create_transferqueue_client(
- self,
- ):
- """Create a client for data system (TransferQueue)."""
- from verl.single_controller.ray.base import get_random_string
- from verl.utils.transferqueue_utils import create_transferqueue_client
-
- client_name = get_random_string(length=6)
-
- self.tq_client = create_transferqueue_client(
- client_id=f"AgentLoopWorker_{client_name}",
- config=self.config.transfer_queue,
- )
-
async def get_trajectory_info(step, index, validate):
"""Get trajectory info.
@@ -837,7 +838,17 @@ async def get_trajectory_info(step, index, validate):
class AgentLoopManager:
- """Agent loop manager that manages a group of agent loop workers."""
+ """Agent loop manager that manages a group of agent loop workers.
+
+ - if worker_group is not None, rollout server is in hybrid mode, share GPUs with training engine.
+ - otherwise, rollout server is in standalone mode, use separate GPUs, e.g., one-step-off/fully async training.
+
+ Args:
+ config (DictConfig): whole config for main entrypoint.
+ worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
+ rollout_resource_pool (RayResourcePool): Resource pool for hybrid mode, only used by TensorRT-LLM.
+ reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
+ """
def __init__(
self,
@@ -846,63 +857,70 @@ def __init__(
rollout_resource_pool: RayResourcePool = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
- """Initialize agent loop manager.
-
- Args:
- config (DictConfig): trainer config.
- worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
- rollout_resource_pool (RayResourcePool): Resource pool for actor rollout (Colocate or Standalone mode).
- reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
- """
self.config = config
+ self.rollout_config, self.model_config = _get_rollout_and_model_config(config)
self.worker_group = worker_group
+ self.rollout_resource_pool = rollout_resource_pool
self.reward_loop_worker_handles = reward_loop_worker_handles
+ assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode"
+
# for recipe to change
if not hasattr(self, "rollout_replica_class"):
- self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
+ self.rollout_replica_class = get_rollout_replica_class(self.rollout_config.name)
if not hasattr(self, "agent_loop_workers_class"):
self.agent_loop_workers_class = ray.remote(AgentLoopWorker)
- self._initialize_llm_servers(rollout_resource_pool)
- self._init_agent_loop_workers()
+ @classmethod
+ @auto_await
+ async def create(
+ cls,
+ config: DictConfig,
+ worker_group: RayWorkerGroup = None,
+ rollout_resource_pool: RayResourcePool = None,
+ reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
+ ):
+ """Create agent loop manager."""
+ instance = cls(config, worker_group, rollout_resource_pool, reward_loop_worker_handles)
+ await instance._initialize_llm_servers()
+ await instance._init_agent_loop_workers()
+ return instance
- def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool):
+ async def _initialize_llm_servers(self):
rollout_world_size = (
- self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
- * self.config.actor_rollout_ref.rollout.data_parallel_size
- * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size
+ self.rollout_config.tensor_model_parallel_size
+ * self.rollout_config.data_parallel_size
+ * self.rollout_config.pipeline_model_parallel_size
)
world_size = (
self.worker_group.world_size
if self.worker_group
- else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
+ else self.rollout_config.n_gpus_per_node * self.rollout_config.nnodes
)
num_replicas = world_size // rollout_world_size
- rollout_config = self.config.actor_rollout_ref.rollout
- model_config = self.config.actor_rollout_ref.model
self.rollout_replicas = [
self.rollout_replica_class(
replica_rank=replica_rank,
- config=rollout_config,
- model_config=model_config,
- gpus_per_node=self.config.trainer.n_gpus_per_node,
+ config=self.rollout_config,
+ model_config=self.model_config,
+ gpus_per_node=self.rollout_config.n_gpus_per_node,
)
for replica_rank in range(num_replicas)
]
- if self.worker_group and rollout_config.name != "trtllm":
- self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
- elif self.worker_group and rollout_config.name == "trtllm":
- self._run_all(
- [
- server.init_hybrid_colocated(self.worker_group, rollout_resource_pool)
+ if self.worker_group and self.rollout_config.name != "trtllm":
+ await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
+ # TODO: unify trtllm to init_hybrid
+ elif self.worker_group and self.rollout_config.name == "trtllm":
+ await asyncio.gather(
+ *[
+ server.init_hybrid_colocated(self.worker_group, self.rollout_resource_pool)
for server in self.rollout_replicas
]
)
else:
- self._run_all([server.init_standalone() for server in self.rollout_replicas])
+ await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])
self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]
@@ -910,14 +928,14 @@ def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool):
print(f"AgentLoopManager: {self.server_addresses}")
# Update Prometheus configuration with server addresses
- if rollout_config.prometheus.enable:
- if rollout_config.disable_log_stats:
+ if self.rollout_config.prometheus.enable:
+ if self.rollout_config.disable_log_stats:
raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.")
- update_prometheus_config(rollout_config.prometheus, self.server_addresses, rollout_config.name)
+ update_prometheus_config(self.rollout_config.prometheus, self.server_addresses, self.rollout_config.name)
- def _init_agent_loop_workers(self):
+ async def _init_agent_loop_workers(self):
self.agent_loop_workers = []
- num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers
+ num_workers = self.rollout_config.agent.num_workers
node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0]
for i in range(num_workers):
@@ -932,7 +950,8 @@ def _init_agent_loop_workers(self):
).remote(self.config, self.server_handles, self.reward_loop_worker_handles)
)
- def generate_sequences(self, prompts: DataProto) -> DataProto:
+ @auto_await
+ async def generate_sequences(self, prompts: DataProto) -> DataProto:
"""Split input batch and dispatch to agent loop workers.
Args:
@@ -943,8 +962,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
"""
chunkes = prompts.chunk(len(self.agent_loop_workers))
- outputs = ray.get(
- [
+ outputs = await asyncio.gather(
+ *[
worker.generate_sequences.remote(chunk)
for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
]
@@ -985,20 +1004,17 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data
return timing
- def clear_kv_cache(self):
+ @auto_await
+ async def clear_kv_cache(self):
"""Clear all rollout kv cache, but don`t sleep."""
- self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas])
+ await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
- def start_profile(self, **kwargs):
+ @auto_await
+ async def start_profile(self, **kwargs):
"""Start profiling on all rollout replicas."""
- self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas])
+ await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])
- def stop_profile(self):
+ @auto_await
+ async def stop_profile(self):
"""Stop profiling on all rollout replicas."""
- self._run_all([replica.stop_profile() for replica in self.rollout_replicas])
-
- def _run_all(self, tasks: list[asyncio.Task]):
- async def run_all():
- await asyncio.gather(*tasks)
-
- asyncio.run(run_all())
+ await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py
index 7c479362aa4..6ad3aa429b3 100644
--- a/verl/experimental/agent_loop/single_turn_agent_loop.py
+++ b/verl/experimental/agent_loop/single_turn_agent_loop.py
@@ -17,7 +17,6 @@
from uuid import uuid4
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
-from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.profiler import simple_timer
logger = logging.getLogger(__file__)
@@ -30,12 +29,8 @@ class SingleTurnAgentLoop(AgentLoopBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
- self.response_length = self.config.actor_rollout_ref.rollout.response_length
-
- tool_config_path = self.config.data.tool_config_path
- tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
- self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
+ self.prompt_length = self.rollout_config.prompt_length
+ self.response_length = self.rollout_config.response_length
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
@@ -48,7 +43,6 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
# 2. apply chat template and tokenize
prompt_ids = await self.apply_chat_template(
messages,
- tools=self.tool_schemas,
images=images,
videos=videos,
)
@@ -81,4 +75,8 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
num_turns=2,
metrics=metrics,
)
+
+ # keeping the schema consistent with tool_agent_loop
+ output.extra_fields.update({"turn_scores": [], "tool_rewards": []})
+
return output
diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py
index f98485a6781..c649a2fc3fd 100644
--- a/verl/experimental/agent_loop/tool_agent_loop.py
+++ b/verl/experimental/agent_loop/tool_agent_loop.py
@@ -21,13 +21,10 @@
import torch
from PIL import Image
-from transformers import AutoProcessor, AutoTokenizer
from verl.experimental.agent_loop.agent_loop import (
AgentLoopBase,
AgentLoopOutput,
- AsyncLLMServerManager,
- DictConfigWrap,
register,
)
from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
@@ -88,43 +85,35 @@ def __init__(
# Temporary state for tool calls
self.tool_calls: list[FunctionCall] = []
+ self.routed_experts = None
+
# Extra fields for dynamic addition, e.g., tool session data
self.extra_fields: dict[str, Any] = {}
@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
- def __init__(
- self,
- trainer_config: DictConfigWrap,
- server_manager: AsyncLLMServerManager,
- tokenizer: AutoTokenizer,
- processor: AutoProcessor,
- **kwargs,
- ):
- super().__init__(trainer_config, server_manager, tokenizer, processor, **kwargs)
- config = trainer_config.config
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
# Initialize tools from config file
- self.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns
- self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
- self.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls
- self.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length
- self.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side
- tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
+ self.max_user_turns = self.rollout_config.multi_turn.max_user_turns
+ self.max_assistant_turns = self.rollout_config.multi_turn.max_assistant_turns
+ self.max_parallel_calls = self.rollout_config.multi_turn.max_parallel_calls
+ self.max_tool_response_length = self.rollout_config.multi_turn.max_tool_response_length
+ self.tool_response_truncate_side = self.rollout_config.multi_turn.tool_response_truncate_side
+ tool_config_path = self.rollout_config.multi_turn.tool_config_path
tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
self.tools = {tool.name: tool for tool in tool_list}
self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
- self.tool_parser = ToolParser.get_tool_parser(
- config.actor_rollout_ref.rollout.multi_turn.format, self.tokenizer
- )
- self.tool_parser_name = config.actor_rollout_ref.rollout.multi_turn.format
+ self.tool_parser = ToolParser.get_tool_parser(self.rollout_config.multi_turn.format, self.tokenizer)
+ self.tool_parser_name = self.rollout_config.multi_turn.format
- self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
- self.response_length = config.actor_rollout_ref.rollout.response_length
+ self.prompt_length = self.rollout_config.prompt_length
+ self.response_length = self.rollout_config.response_length
# Initialize interactions from config file
- self.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path
+ self.interaction_config_file = self.rollout_config.multi_turn.interaction_config_path
if self.interaction_config_file:
self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(
self.interaction_config_file
@@ -203,6 +192,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
else None,
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
+ routed_experts=agent_data.routed_experts,
extra_fields={},
)
output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
@@ -350,10 +340,14 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
)
else:
+            # Pass None (rather than an empty list) for images/videos when this turn
+            # produced no new media, to stay compatible with downstream image processing logic.
+ images = new_images_this_turn if new_images_this_turn else None
+ videos = None
response_ids = await self.apply_chat_template(
add_messages,
- images=new_images_this_turn, # Using local variable
- videos=None,
+ images=images,
+ videos=videos,
remove_system_prompt=True,
)
diff --git a/verl/experimental/fully_async_policy/README.md b/verl/experimental/fully_async_policy/README.md
index a24e7610102..b7ff1756459 100644
--- a/verl/experimental/fully_async_policy/README.md
+++ b/verl/experimental/fully_async_policy/README.md
@@ -105,9 +105,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
-| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for accelerating, default `True` |
-| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False` |
-| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096` |
| `async_training.use_trainer_do_validate` | Whether use trainer node to do validate process, default `False` |
**Further Explanation:**
@@ -181,27 +178,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
mode d
(async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`.
-* `async_training.checkpoint_engine.enable`
-
- Enabling the checkpoint engine generally reduces synchronization time overhead by more than 60% compared to
- the original per-tensor parameter synchronization method. However, assembling buckets incurs additional
- temporary GPU memory overhead.
-
-* `async_training.checkpoint_engine.overlap_broadcast_and_consume`
-
- Enabling pipeline between the broadcast and load_weights parameters will allocate additional GPU memory.
- Since the main time consumption for parameter synchronization is not in the broadcast and load_weights phases,
- but in the parameter generation phase (by megatron or FSDP), this option is off by default.
-
-* `async_training.checkpoint_engine.device_buffer_size_M`
-
- It controls the size of the memory buffer used for synchronization when the checkpoint-engine is enabled.
- The actual `bucket_size` = `max(device_buffer_size_M, maximum parameter tensor size)`.
- * When enable `overlap_broadcast_and_consume`, the additional device memory overhead of
- trainer rank is `3 * bucket_size`and rollout rank is `2 * bucket_size`。
- * When disable `overlap_broadcast_and_consume`, the additional device memory overhead of
- trainer rank is `2 * bucket_size`and rollout rank is `1 * bucket_size`。
-
* `async_training.use_trainer_do_validate`
It controls whether to use the trainer's `do_validate` method for validation.
diff --git a/verl/experimental/fully_async_policy/README_zh.md b/verl/experimental/fully_async_policy/README_zh.md
index 19a257247c3..ad2e52e4167 100644
--- a/verl/experimental/fully_async_policy/README_zh.md
+++ b/verl/experimental/fully_async_policy/README_zh.md
@@ -82,9 +82,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
| `async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后,进行一次参数同步 |
| `async_training.staleness_threshold` | 新鲜度控制 |
| `async_training.partial_rollout` | 是否进行partial_rollout |
-| `async_training.checkpoint_engine.enable` | 是否开启checkpoint_engine模式的加速,默认值True |
-| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | 启动checkpoint_engine时,是否在参数同步时在broadcast和加载之间使用流水,默认值False |
-| `async_training.checkpoint_engine.device_buffer_size_M` | 启动checkpoint_engine时,组装的bucket的大小(MB),默认为4096 |
| `async_training.use_trainer_do_validate` | 是否使用Trainer的do_validate方法进行validation,默认值False |
**进一步的解释:**
@@ -146,20 +143,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
在实际测试中,我们发现,如果单次下发的样本较少,由于数据分发的顺序,会导致训练不稳定,response 长度变长。
在这里,我们额外提供 require_batches 进行流式分发,单次参与训练的样本数量控制。
-* `async_training.checkpoint_engine.enable`
-
- 开启checkpoint engine后,相较于原始的逐tensor的参数同步方式,同步时间开销普遍可以降低60%以上。但是组装bucket会带来额外的临时显存开销。
-
-* `async_training.checkpoint_engine.overlap_broadcast_and_consume`
-
- 开启参数broadcast和load_weights之间的流水后,会进一步额外申请更多显存。由于目前分析参数同步的主要耗时并非来自broadcast和load_weights阶段,而是在参数生成阶段(由megatron或FSDP),因此该开关默认关闭。
-
-* `async_training.checkpoint_engine.device_buffer_size_M`
-
- 控制开启checkpoint engine后,用于同步的显存buffer大小。实际的`bucket_size` = `max(device_buffer_size_M, 最大参数tensor size)`
- * 在开启`overlap_broadcast_and_consume`时,trainer节点的临时额外显存开销为 `3 * bucket_size`, rollout节点的临时额外显存开销为`2 * bucket_size`。
- * 在关闭`overlap_broadcast_and_consume`时,trainer节点的临时额外显存开销为 `2 * bucket_size`, rollout节点的临时额外显存开销为`1 * bucket_size`。
-
* `async_training.use_trainer_do_validate`
控制是否使用trainer的`do_validate`方法进行validation。
diff --git a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py
index 3098e48ba96..88a012224eb 100644
--- a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py
+++ b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py
@@ -28,11 +28,11 @@
AsyncLLMServerManager,
DictConfigWrap,
_agent_loop_registry,
+ _get_rollout_and_model_config,
get_trajectory_info,
)
-from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
from verl.protocol import DataProto
-from verl.single_controller.ray import RayWorkerGroup
+from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
from verl.utils.rollout_trace import (
rollout_trace_attr,
rollout_trace_op,
@@ -102,7 +102,7 @@ async def generate_sequences_no_post(
Returns:
list[AgentLoopOutput]: List of agent loop outputs, one per sample in the batch.
"""
- config = self.config.actor_rollout_ref.rollout
+ config = self.rollout_config
sampling_params = dict(
temperature=config.temperature,
top_p=config.top_p,
@@ -191,7 +191,7 @@ async def _partial_run_agent_loop(
tokenizer=self.tokenizer,
processor=self.processor,
dataset_cls=self.dataset_cls,
- dataset_config=DictConfigWrap(config=self.config.data),
+ data_config=DictConfigWrap(config=self.config.data),
)
output: AgentLoopOutput = await agent_loop.run(
sampling_params, cancellation_event=self.cancellation_event, **kwargs
@@ -219,15 +219,17 @@ def __init__(
self,
config: DictConfig,
worker_group: RayWorkerGroup = None,
+ rollout_resource_pool: RayResourcePool = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.config = config
+ self.rollout_config, self.model_config = _get_rollout_and_model_config(config)
self.worker_group = worker_group
self.reward_loop_worker_handles = reward_loop_worker_handles
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
# Select rollout replica class based on rollout name
- rollout_name = config.actor_rollout_ref.rollout.name
+ rollout_name = self.rollout_config.name
if rollout_name == "sglang":
from verl.experimental.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica
@@ -246,63 +248,6 @@ def __init__(
self.server_addresses = None
self.agent_loop_workers = None
- @classmethod
- async def create(
- cls,
- config: DictConfig,
- worker_group: RayWorkerGroup = None,
- reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
- ):
- instance = cls(config, worker_group, reward_loop_worker_handles)
- await instance._async_init()
- return instance
-
- async def _async_init(self):
- await self._initialize_llm_servers_async()
- self._init_agent_loop_workers()
-
- async def _initialize_llm_servers_async(self):
- rollout_world_size = (
- self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
- * self.config.actor_rollout_ref.rollout.data_parallel_size
- * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size
- )
- world_size = (
- self.worker_group.world_size
- if self.worker_group
- else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
- )
- num_replicas = world_size // rollout_world_size
-
- rollout_config = self.config.actor_rollout_ref.rollout
- model_config = self.config.actor_rollout_ref.model
- self.rollout_replicas = [
- self.rollout_replica_class(
- replica_rank=replica_rank,
- config=rollout_config,
- model_config=model_config,
- gpus_per_node=self.config.trainer.n_gpus_per_node,
- )
- for replica_rank in range(num_replicas)
- ]
-
- if self.worker_group:
- await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
- else:
- await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])
-
- self.server_handles = [server._server_handle for server in self.rollout_replicas]
- self.server_addresses = [server._server_address for server in self.rollout_replicas]
-
- print(f"AgentLoopManager: {self.server_addresses}")
- # Update Prometheus configuration with server addresses
- if rollout_config.prometheus.enable:
- if rollout_config.disable_log_stats:
- raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.")
- await asyncio.to_thread(
- update_prometheus_config, rollout_config.prometheus, self.server_addresses, rollout_config.name
- )
-
async def generate_single_sample_async(
self,
sample: DataProto,
diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
index cbfb4954f4f..6982184f8f6 100644
--- a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
+++ b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
@@ -30,9 +30,9 @@ class PartialSingleTurnAgentLoop(AgentLoopBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
- self.response_length = self.config.actor_rollout_ref.rollout.response_length
- self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {})
+ self.prompt_length = self.rollout_config.prompt_length
+ self.response_length = self.rollout_config.response_length
+ self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {})
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
output: Optional[AgentLoopOutput] = kwargs.get("output", None)
@@ -124,6 +124,8 @@ def get_prompt_ids():
"is_cancel": is_cancel,
"param_version_start": param_version_start,
"param_version_end": param_version_end,
+ "turn_scores": [],
+ "tool_rewards": [],
},
multi_modal_data=multi_modal_data,
# multi_modal_data={"image": image_data} if image_data is not None else {},
diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py
index 0082fc13bc8..370587f0364 100644
--- a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py
+++ b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py
@@ -33,9 +33,9 @@ class AsyncPartialToolAgentLoop(ToolAgentLoop):
"""
- def __init__(self, trainer_config, **kwargs):
- super().__init__(trainer_config, **kwargs)
- self.enable_partial_rollout = trainer_config.config.async_training.get("partial_rollout", False)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.enable_partial_rollout = self.config.async_training.get("partial_rollout", False)
# async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
async def run(
diff --git a/verl/experimental/fully_async_policy/base_detach_sync.py b/verl/experimental/fully_async_policy/base_detach_sync.py
deleted file mode 100644
index c0924417d78..00000000000
--- a/verl/experimental/fully_async_policy/base_detach_sync.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-# Copyright 2025 Meituan Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import logging
-import os
-import threading
-
-import torch
-from omegaconf import DictConfig
-from ray.util.collective import collective
-
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import get_torch_device, is_npu_available
-from verl.utils.distributed import stateless_init_process_group
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-
-class BaseDetachNcclSync:
- _bucket_size_mb = 1024.0
- _sync_history = []
- _max_history_size = 20
- _last_avg_bucket_size = 1024.0
-
- def __init__(self, config: DictConfig, role: str):
- self._bg_loop = asyncio.new_event_loop()
- self._bg_thread = threading.Thread(
- target=self._start_background_loop, args=(self._bg_loop,), name="rollout_actor_async_worker", daemon=True
- )
- self._bg_thread.start()
- logger.info(f"[DetachNcclSync] Background thread for SGLang sync started. PID: {os.getpid()}")
-
- @classmethod
- def get_bucket_size_mb(cls):
- return cls._bucket_size_mb
-
- @classmethod
- def get_last_avg_bucket_size(cls):
- return cls._last_avg_bucket_size
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
- def get_last_avg_bucket_size_remote(self):
- return BaseDetachNcclSync._last_avg_bucket_size
-
- @classmethod
- def record_sync_metrics(cls, bucket_size_mb, sync_time):
- """Dynamically adjust the bucket size based on past synchronization times."""
- bucket_size_mb_value = bucket_size_mb[0] if isinstance(bucket_size_mb, list) else bucket_size_mb
- print(f"[DetachNcclSync] sync_metrics: bucket_size_mb={bucket_size_mb_value:.2f}MB, sync_time={sync_time:.2f}s")
- cls._sync_history.append((bucket_size_mb_value, sync_time))
- if len(cls._sync_history) > cls._max_history_size:
- cls._sync_history.pop(0)
-
- MIN_BUCKET_SIZE_MB = 512
- MAX_BUCKET_SIZE_MB = 8192 # 8GB
-
- if len(cls._sync_history) < 4:
- cls._bucket_size_mb = min(MAX_BUCKET_SIZE_MB, cls._bucket_size_mb * 1.5)
- else:
- times = [t for _, t in cls._sync_history]
- buckets = [b for b, _ in cls._sync_history]
- recent_avg_time = sum(times[-2:]) / 2
- previous_avg_time = sum(times[-4:-2]) / 2
- recent_avg_bucket = sum(buckets[-2:]) / 2
- previous_avg_bucket = sum(buckets[-4:-2]) / 2
-
- performance_improved = recent_avg_time < previous_avg_time
- bucket_increased = recent_avg_bucket > previous_avg_bucket
- time_change_ratio = (
- abs(recent_avg_time - previous_avg_time) / previous_avg_time if previous_avg_time > 0 else 0.0
- )
-
- if time_change_ratio > 0.2:
- increase_step, decrease_step = 1.2, 0.8
- elif time_change_ratio > 0.1:
- increase_step, decrease_step = 1.1, 0.9
- elif time_change_ratio > 0.05:
- increase_step, decrease_step = 1.05, 0.95
- else:
- increase_step, decrease_step = 1.02, 0.98
-
- should_increase = (performance_improved and bucket_increased) or (
- not performance_improved and not bucket_increased
- )
- step = increase_step if should_increase else decrease_step
- new_size = cls._bucket_size_mb * step
- cls._bucket_size_mb = min(MAX_BUCKET_SIZE_MB, max(MIN_BUCKET_SIZE_MB, new_size))
-
- def _start_background_loop(self, loop):
- asyncio.set_event_loop(loop)
- try:
- loop.run_forever()
- except Exception as e:
- logger.error(f"[DetachNcclSync] Background loop crashed: {e}")
-
- def _run_async_safely(self, coro):
- if not self._bg_thread.is_alive():
- raise RuntimeError("Background thread for SGLang sync is not running!")
-
- future = asyncio.run_coroutine_threadsafe(coro, self._bg_loop)
- return future.result()
-
- def __del__(self):
- if hasattr(self, "_bg_loop") and self._bg_loop.is_running():
- self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
- if hasattr(self, "_bg_thread") and self._bg_thread.is_alive():
- self._bg_thread.join(timeout=1.0)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int):
- from .checkpoint_engine import CheckpointEngine
-
- current_rank = torch.distributed.get_rank() + rank_offset
- actor_ranks = list(range(actor_num))
- rollout_ranks = [rank + actor_num for rank in range(rollout_num)]
- assert rank_offset == 0 or rank_offset == actor_num
-
- self.checkpoint_engine = CheckpointEngine(
- current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M
- )
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):
- rank = torch.distributed.get_rank() + rank_offset
- self._weight_sync_group = stateless_init_process_group(
- master_address,
- master_port,
- rank,
- world_size,
- get_torch_device().current_device(),
- )
-
- @staticmethod
- def get_inference_model(rollout):
- """
- Get models according to different types of inference_engine
- Args:
- rollout: rollout object
- Returns:
- model: model object (for vllm) or rollout object itself (for sglang)
- """
- inference_engine = rollout.inference_engine
- if hasattr(inference_engine, "llm_engine"):
- inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
- elif hasattr(inference_engine, "worker"):
- inference_model = inference_engine.worker.model_runner.model
- else:
- raise AttributeError(
- f"Unsupported inference_engine type: {type(inference_engine)}. "
- f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)."
- )
- return inference_model
-
- def _sync_sglang_weights(self, inference_model, params, sync_group_name):
- bucket_size_bytes = int(self.get_bucket_size_mb() * 1024 * 1024)
- actual_bucket_sizes = []
- current_batch = []
- current_batch_size = 0
-
- def flush_batch():
- if current_batch:
- actual_bucket_sizes.append(current_batch_size / (1024 * 1024))
- self._run_async_safely(self.update_weights(inference_model, iter(current_batch)))
- get_torch_device().synchronize()
- current_batch.clear()
-
- for key, shape, dtype in self._weights_info:
- tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
- if self._is_actor:
- assert key in params
- origin_data = params[key]
- if hasattr(origin_data, "full_tensor"):
- origin_data = origin_data.full_tensor()
- if torch.distributed.get_rank() == 0:
- tensor.copy_(origin_data)
- collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
-
- tensor_size = tensor.numel() * tensor.element_size()
- current_batch.append((key, tensor))
- current_batch_size += tensor_size
-
- if current_batch_size >= bucket_size_bytes:
- flush_batch()
- current_batch_size = 0
-
- flush_batch()
- cls = type(self)
- cls._last_avg_bucket_size = (
- sum(actual_bucket_sizes) / len(actual_bucket_sizes) if actual_bucket_sizes else self.get_bucket_size_mb()
- )
-
- # Resume kv_cache after weights sync to restore GPU memory released during pause
- if self._is_rollout and self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
- self._run_async_safely(inference_model.resume_memory_occupation(tags=["kv_cache"]))
-
- def _sync_vllm_weights(self, inference_model, params, sync_group_name):
- for key, shape, dtype in self._weights_info:
- tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
- if self._is_actor:
- assert key in params
- origin_data = params[key]
- if hasattr(origin_data, "full_tensor"):
- origin_data = origin_data.full_tensor()
- if torch.distributed.get_rank() == 0:
- tensor.copy_(origin_data)
- if is_npu_available:
- self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
- else:
- collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
- if self._is_rollout:
- inference_model.load_weights([(key, tensor)])
-
- async def update_weights(self, inference_engine, params):
- from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
-
- await sgl_update_weights(
- engine=inference_engine,
- params_batch=params,
- device_mesh_key="infer_tp",
- device_mesh=self.rollout_device_mesh,
- )
-
- if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
- await inference_engine.flush_cache()
diff --git a/verl/experimental/fully_async_policy/checkpoint_engine.py b/verl/experimental/fully_async_policy/checkpoint_engine.py
deleted file mode 100644
index 28f932d61b3..00000000000
--- a/verl/experimental/fully_async_policy/checkpoint_engine.py
+++ /dev/null
@@ -1,522 +0,0 @@
-# Copyright 2025 Meituan Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This logic is largely copied from:
-- https://github.com/MoonshotAI/checkpoint-engine
-"""
-
-import concurrent.futures
-import os
-import re
-import socket
-import subprocess
-import threading
-from collections.abc import Callable
-from functools import lru_cache
-from typing import TYPE_CHECKING, Annotated, Any, TypedDict
-
-import torch
-import zmq
-from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from ray.util.collective import collective
-
-from verl.utils.device import (
- get_device_name,
- get_torch_device,
-)
-
-if TYPE_CHECKING:
- from typing import TypeVar
-
- from typing_extensions import TypedDict
-
- class FileMeta(TypedDict):
- key: str # parameter name
- dtype: torch.dtype
- shape: torch.Size
- type: type
- tp_concat_dim: int
-
- T = TypeVar("T")
-
-
-def _dt_validate(value: Any) -> torch.dtype:
- """Validate the input value to ensure it is a valid torch.dtype"""
- if isinstance(value, str):
- if not value.startswith("torch."):
- raise ValueError(f"dtype {value} should start with torch.")
- try:
- value = getattr(torch, value.split(".")[1])
- except AttributeError as e:
- raise ValueError(f"unknown dtype: {value}") from e
- if not isinstance(value, torch.dtype):
- raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
- return value
-
-
-# Annotated type for torch.dtype with validation and serialization
-_TorchDtype = Annotated[
- torch.dtype,
- PlainValidator(_dt_validate),
- PlainSerializer(lambda x: str(x), return_type=str),
- WithJsonSchema({"type": "string"}, mode="serialization"),
-]
-
-
-def _size_validate(value: Any) -> torch.Size:
- """Validate the input value to ensure it is a valid torch.Size"""
- if isinstance(value, list | tuple):
- return torch.Size(value)
- if not isinstance(value, torch.Size):
- raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
- return value
-
-
-# Annotated type for torch.Size with validation and serialization
-_TorchSize = Annotated[
- torch.Size,
- PlainValidator(_size_validate),
- PlainSerializer(lambda x: tuple(x), return_type=tuple),
- WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
-]
-
-
-def _tensor_validate(value: Any) -> torch.Tensor:
- """Validate the input value to ensure it is a valid torch.Tensor"""
- if isinstance(value, torch.Tensor):
- return value
- raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
-
-
-# Annotated type for torch.Tensor with validation
-_TorchTensor = Annotated[
- torch.Tensor,
- PlainValidator(_tensor_validate),
-]
-
-
-class ParameterMeta(BaseModel):
- """Metadata for a parameter including name, dtype, and shape"""
-
- name: str
- dtype: _TorchDtype
- shape: _TorchSize
-
-
-class MemoryBuffer(BaseModel):
- """
- MemoryBuffer assembles a group of parameter tensors into a single buffer,
- and records the meta information of each original parameter.
- """
-
- buffer: _TorchTensor
- size: int # size of buffer in bytes
- metas: list[ParameterMeta]
-
-
-class MemoryBufferMeta(BaseModel):
- """The meta info of MemoryBuffer, but not store the buffer data"""
-
- size: int
- metas: list[ParameterMeta]
-
-
-# 256 bytes alignment when flatten torch tensors to uint8 buffer
-_ALIGN_SIZE = 256
-
-
-def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
- """
- Calculate the aligned size of a torch tensor
-
- If the tensor's size (in bytes) cannot be evenly divided by _ALIGN_SIZE,
- it will be rounded up to the nearest multiple of _ALIGN_SIZE.
-
- Args:
- dtype (torch.dtype): The data type of the tensor (e.g., torch.float32, torch.int64).
- shape (torch.Size): The shape of the tensor, representing its dimensions.
-
- Returns:
- int: The aligned size of the tensor in bytes.
- """
- return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
-
-
-@lru_cache(maxsize=1)
-def get_ip() -> str:
- try:
- # try to get ip from network interface
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(("8.8.8.8", 80))
- return s.getsockname()[0]
- except Exception as e: # noqa: BLE001
- # fallback to get ip from hostname
- print(f"fail to get ip from network interface, fallback to get ip from hostname: {e}")
- return socket.gethostbyname(socket.gethostname())
-
-
-def npu_generate_uuid() -> str:
- """Generate uuid for each npu device"""
- str_pid = str(os.getpid())
- npu_num = 8
- try:
- for npu_id in range(npu_num):
- cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
- result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603
- str_result = str(result.stdout)
- if str_pid in str_result:
- # In A3 server, one NPU has two chips.
- match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
- chip_count = int(match_chip_count.group(1))
- search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
- match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
- chip_id = int(match_chip_id.group(1))
- return f"{get_ip()}-{npu_id * chip_count + chip_id}"
- raise ValueError("The current process is not running on the npu device")
- except subprocess.CalledProcessError as e:
- raise ValueError("The current process is not running on the npu device") from e
-
-
-def _get_physical_device_id(device_index: int | None = None) -> str:
- """
- Get the physical device (GPU or NPU) uuid of the current device
- """
- try:
- if get_device_name() == "npu":
- return f"NPU-{npu_generate_uuid()}"
- else:
- return f"GPU-{get_torch_device().get_device_properties(device_index).uuid!s}"
- except AssertionError as e:
- raise ValueError(f"fail to get physical gpu id {device_index}") from e
-
-
-class FlattenedTensorMetadata(TypedDict):
- name: str
- shape: torch.Size
- dtype: torch.dtype
- # specify the start offset of this tensor in shared ipc_buffer tensor
- offset: int
-
-
-def _to_flattened_tensor_meta(metas: list[ParameterMeta], offset: int = 0) -> list[FlattenedTensorMetadata]:
- """
- compute the offset of each parameter in the buffer
-
- Args:
- metas (list[ParameterMeta]): The list of parameter metas info
- offset (int): The start offset of the buffer. Defaults to 0.
-
- Returns:
- list[FlattenedTensorMetadata]: The list of FlattenedTensorMetadata:
- """
- ret = []
- for meta in metas:
- size = _align_size(meta.dtype, meta.shape)
- ret.append(
- {
- "name": meta.name,
- "dtype": meta.dtype,
- "shape": meta.shape,
- "offset": offset,
- }
- )
- offset += size
- return ret
-
-
-def _extract_weights(
- flatten_metas: list[FlattenedTensorMetadata], buffer: torch.Tensor
-) -> list[tuple[str, torch.Tensor]]:
- """
- According to the flatten_metas and buffer, extract the weights
- """
-
- assert buffer is not None
- weights: list[tuple[str, torch.Tensor]] = []
- for item in flatten_metas:
- shape = item["shape"]
- if isinstance(shape, list | tuple):
- shape = torch.Size(shape)
- assert isinstance(shape, torch.Size)
- dtype, offset = item["dtype"], item["offset"]
- size = dtype.itemsize * shape.numel()
- tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
- weights.append((item["name"], tensor))
- return weights
-
-
-class CheckpointEngine:
- """
- CheckpointEngine class for control parameters synchronization.
- Each trainer/rollout rank has a CheckpointEngine instance.
- """
-
- def __init__(
- self, current_rank: int, actor_ranks: list[int], rollout_ranks: list[int], device_buffer_size_M: int
- ) -> None:
- self.current_rank = current_rank
- self.actor_ranks = actor_ranks
- self.rollout_ranks = rollout_ranks
- # global_buckets saves the global MemoryBufferMeta infos.
- # Thus each CheckpointEngine instance can control their operations in SPMD
- self.global_buckets: dict[int, list[MemoryBufferMeta]] = None
- # min device_buffer_size for h2d and broadcast
- self.device_buffer_size_M = device_buffer_size_M
-
- # ipc config for broadcast in pipeline mode
- self._zmq_ctx = zmq.Context()
- self._zmq_addr_counter: int = 0
- device_index = self.current_rank % get_torch_device().device_count()
- self._device_uuid = _get_physical_device_id(device_index)
-
- def register_checkpoint(
- self, weights_info: list[tuple[str, torch.Size, torch.dtype]], cpu_named_params: dict[str, torch.Tensor]
- ):
- """
- Register checkpoint information and prepare memory buffers for parameter synchronization.
-
- This function organizes the parameters into memory buckets for efficient synchronization
- and prepares pinned memory buffers for faster data transfer between CPU and device.
-
- Args:
- weights_info (list[tuple[str, torch.Size, torch.dtype]]):
- A list of tuples containing parameter name, shape, and data type.
- cpu_named_params (dict[str, torch.Tensor]):
- A dictionary mapping parameter names to their corresponding CPU tensors.
-
- Steps:
- 1. Calculate the bucket size based on the largest parameter tensor size and the device buffer size.
- 2. Organize parameters into global buckets for each actor rank, ensuring that the total size of each bucket
- does not exceed the bucket size.
- 3. For actor ranks, allocate pinned memory buffers for each bucket and copy the parameter tensors
- into these buffers.
-
- Notes:
- Each CheckpointEngine instance maintains the global buckets metas,
- but stores part of parmas data in host memory
- """
- bucket_size = max(
- self.device_buffer_size_M << 20, max(_align_size(dtype, shape) for _, shape, dtype in weights_info)
- )
- print(
- f"set checkpoint_engine device buffer size: {self.device_buffer_size_M}M, "
- f"and finally set it to {bucket_size >> 20}M considering the largest parameter tensor size"
- )
- self.bucket_size = bucket_size
-
- # global_buckets saves the global MemoryBufferMeta infos.
- if self.global_buckets is None:
- self.global_buckets = {rank: [MemoryBufferMeta(size=0, metas=[])] for rank in self.actor_ranks}
-
- actor_ranks_size = len(self.actor_ranks)
- assert actor_ranks_size > 0, f"actor_ranks:{self.actor_ranks} should not be empty"
- for param_idx, (param_name, param_shape, param_dtype) in enumerate(weights_info):
- # Each parameter is assigned to an actor rank, and only this rank will store it
- assgin_rank = self.actor_ranks[param_idx % actor_ranks_size]
- param_size = _align_size(param_dtype, param_shape)
-
- if self.global_buckets[assgin_rank][-1].size + param_size > bucket_size:
- assert self.global_buckets[assgin_rank][-1].size, (
- f"global_buckets[{assgin_rank}][-1].size:{self.global_buckets[assgin_rank][-1].size}"
- " should not be 0"
- )
- self.global_buckets[assgin_rank].append(MemoryBufferMeta(size=0, metas=[]))
- self.global_buckets[assgin_rank][-1].metas.append(
- ParameterMeta(name=param_name, dtype=param_dtype, shape=param_shape)
- )
- self.global_buckets[assgin_rank][-1].size += param_size
-
- def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
- """Allocate pinned memory for a bucket."""
- buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
- return idx, buffer
-
- def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
- """Copy a tensor into a pinned memory buffer."""
- buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
-
- memory_buffers = [] # for rollout rank, return empty buffer
- if self.current_rank in self.actor_ranks: # is_actor
- local_buckets = self.global_buckets[self.current_rank]
- memory_buffers = [
- MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) for bucket in local_buckets
- ]
-
- # Use thread pool to accelerate organize parameters into buckets
- with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
- futures = [
- executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets)
- ]
- new_futures = []
- for future in concurrent.futures.as_completed(futures):
- idx, buffer = future.result()
- assert buffer.numel() == local_buckets[idx].size, (
- f"buffer numel {buffer.numel()} should be equal to bucket size {local_buckets[idx].size}"
- )
- memory_buffers[idx].buffer = buffer
- print(
- f"[rank{self.current_rank}] register pin_memory for "
- f" bucket {idx + 1}/{len(local_buckets)} finished, "
- f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
- )
- offset = 0
- for meta in local_buckets[idx].metas:
- name = meta.name
- tensor = cpu_named_params[name]
- size = _align_size(tensor.dtype, tensor.shape)
- assert size == _align_size(meta.dtype, meta.shape), (
- f"tensor {name} size {size} should be equal to "
- f"meta size {_align_size(meta.dtype, meta.shape)}"
- )
- new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
- offset += size
- for future in concurrent.futures.as_completed(new_futures):
- future.result()
-
- self.memory_buffers = memory_buffers
-
- def get_max_buckets_num_per_rank(self):
- """
- Get the maximum number of buckets for all rank.
- """
- assert self.global_buckets is not None
- return max(len(buckets) for buckets in self.global_buckets.values())
-
- def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
- """
- Bind zmq socket for broadcast.
- """
-
- def zmq_handle(device_uuid: str) -> str:
- return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
-
- socket_path = zmq_handle(self._device_uuid)
- socket = self._zmq_ctx.socket(zmq.REQ)
- socket.bind(socket_path)
- self._zmq_addr_counter += 1
- return socket, socket_path
-
- def update_checkpoint(self, inference_model, group_name: str, overlap_broadcast_and_consume: bool = False):
- """
- Update the checkpoint by broadcasting and loading weights.
-
- This function handles the synchronization of parameters across ranks by:
- 1. Copying data from memory buffers to device buffers (h2d_buffer).
- 2. Broadcasting the data to all ranks using collective communication.
- 3. Loading the weights into the inference model if provided.
- 4. Optionally, use a pipeline approach for broadcasting and loading weights.
-
- Args:
- inference_model: The model to load weights into. If None (trainer rank), weights are only broadcasted.
- group_name (str): The name of the collective communication group.
- overlap_broadcast_and_consume (bool): Whether to use the pipeline approach
- for broadcasting and loading weights.
- """
- try:
- h2d_buffer: torch.Tensor | None = (
- None
- if self.current_rank in self.rollout_ranks
- else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device())
- )
- # for pipeline mode, we need to allocate 2x buffer size
- broadcast_load_buffer = torch.empty(
- self.bucket_size * (2 if overlap_broadcast_and_consume else 1),
- dtype=torch.uint8,
- device=get_torch_device().current_device(),
- )
- except Exception:
- print(
- "allocate buffer for update_checkpoint failed, "
- "you may need to reduce "
- "config.async_training.checkpoint_engine.device_buffer_size_M"
- )
- raise
-
- max_h2d_iter = self.get_max_buckets_num_per_rank()
-
- if overlap_broadcast_and_consume:
- socket, socket_path = self._bind_zmq_socket()
-
- # Define a function to update weights from IPC
- def update_weights_from_ipc_(socket_path):
- zmq_ctx = zmq.Context()
- socket = zmq_ctx.socket(zmq.REP)
- socket.connect(socket_path)
- socket.recv_pyobj()
- socket.send(b"")
-
- while True:
- payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
- if payload is None:
- # means the update is done
- get_torch_device().synchronize()
- socket.send(b"")
- break
- assert isinstance(payload, list)
- if inference_model is not None:
- inference_model.load_weights(_extract_weights(payload, broadcast_load_buffer))
- get_torch_device().synchronize()
- socket.send(b"")
-
- req_thread = threading.Thread(
- target=update_weights_from_ipc_,
- args=(socket_path,),
- )
- req_thread.start()
- socket.send_pyobj(b"")
- get_torch_device().synchronize()
-
- gidx = 0
- local_buckets = self.global_buckets.get(self.current_rank, [])
-
- for i in range(max_h2d_iter):
- # Step 1: Each actor rank copy the parameter tensor into device memory
- if i < len(self.memory_buffers):
- h2d_buffer[: local_buckets[i].size].data.copy_(self.memory_buffers[i].buffer)
-
- # Step 2: Broadcast the device data in turn
- for broadcast_rank, _buckets in self.global_buckets.items():
- if i >= len(_buckets):
- continue
- bucket = _buckets[i]
-
- # Prepare the broadcast buffer
- start = gidx % 2 * self.bucket_size if overlap_broadcast_and_consume else 0
- buffer_b: torch.Tensor = broadcast_load_buffer[start : start + bucket.size]
- if broadcast_rank == self.current_rank:
- buffer_b.data.copy_(h2d_buffer[: bucket.size])
-
- # Broadcast the buffer to all ranks
- collective.broadcast(buffer_b, src_rank=broadcast_rank, group_name=group_name)
-
- if overlap_broadcast_and_consume:
- socket.recv()
- collective.barrier(group_name=group_name)
- socket.send_pyobj(_to_flattened_tensor_meta(bucket.metas, start))
- elif inference_model is not None:
- named_tensor = _to_flattened_tensor_meta(bucket.metas, 0)
- inference_model.load_weights(_extract_weights(named_tensor, buffer_b))
-
- gidx += 1
-
- if overlap_broadcast_and_consume:
- socket.recv()
- socket.send_pyobj(None)
- socket.recv()
- req_thread.join()
- socket.close()
-
- collective.barrier(group_name=group_name)
- # clear host memory cache
- self.memory_buffers = []
diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml
index eece540865c..5ef1bc6c813 100644
--- a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml
+++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml
@@ -6,6 +6,9 @@ defaults:
- ppo_megatron_trainer
- _self_
+trainer:
+ use_legacy_worker_impl: disable
+
async_training:
# Maximum samples staleness threshold
@@ -25,17 +28,6 @@ async_training:
use_trainer_do_validate: False
- # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer
- checkpoint_engine:
- # Whether to use checkpoint_engine
- enable: True
-
- # Device buffer size for checkpoint_engine, default is 4096 MB
- device_buffer_size_M: 4096
-
- # Enable the pipeline for broadcasting and updating parameters, but it requires more device memory
- overlap_broadcast_and_consume: False
-
# Rollout config
rollout:
@@ -62,17 +54,15 @@ data:
gen_batch_size: 1
actor_rollout_ref:
- # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer
- checkpoint_engine: ${oc.select:async_training.checkpoint_engine, null}
rollout:
# Must be turned off! Otherwise, Parameter synchronization cannot be performed.
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
- # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
- # TODO: Can be removed in the future once parameter synchronization is ready.
- load_format: "auto"
+
+ checkpoint_engine:
+ backend: "nccl"
actor:
# Must use rollout log probs for training
diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml
index 7dece1cd479..1f4b4db8c82 100644
--- a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml
+++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml
@@ -6,6 +6,9 @@ defaults:
- ppo_trainer
- _self_
+trainer:
+ use_legacy_worker_impl: disable
+
async_training:
# Maximum samples staleness threshold
@@ -24,18 +27,6 @@ async_training:
# whether to use trainer do_validate
use_trainer_do_validate: False
-
- # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer
- checkpoint_engine:
- # Whether to use checkpoint_engine
- enable: True
-
- # Device buffer size for checkpoint_engine, default is 4096 MB
- device_buffer_size_M: 4096
-
- # Enable the pipeline for broadcasting and updating parameters, but it requires more device memory
- overlap_broadcast_and_consume: False
-
# Rollout config
rollout:
@@ -62,17 +53,15 @@ data:
gen_batch_size: 1
actor_rollout_ref:
- # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer
- checkpoint_engine: ${oc.select:async_training.checkpoint_engine, null}
rollout:
# Must be turned off! Otherwise, Parameter synchronization cannot be performed.
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
- # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
- # TODO: Can be removed in the future once parameter synchronization is ready.
- load_format: "auto"
+
+ checkpoint_engine:
+ backend: "nccl"
actor:
# Must use rollout log probs for training
@@ -82,4 +71,4 @@ actor_rollout_ref:
# And it can be used in conjunction with other rollout_correction algorithms.
algorithm:
rollout_correction:
- bypass_mode: True
\ No newline at end of file
+ bypass_mode: True
diff --git a/verl/experimental/fully_async_policy/fsdp_workers.py b/verl/experimental/fully_async_policy/fsdp_workers.py
deleted file mode 100644
index 227ae06307d..00000000000
--- a/verl/experimental/fully_async_policy/fsdp_workers.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-# Copyright 2025 Meituan Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import time
-
-import torch
-import torch.distributed
-from omegaconf import DictConfig
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-from verl.experimental.fully_async_policy.base_detach_sync import BaseDetachNcclSync
-from verl.experimental.fully_async_policy.fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import (
- get_device_name,
- get_torch_device,
-)
-from verl.utils.fsdp_utils import (
- fsdp_version,
- load_fsdp_model_to_gpu,
- offload_fsdp_model_to_cpu,
-)
-from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-device_name = get_device_name()
-
-__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"]
-
-
-class DetachNcclSync(BaseDetachNcclSync, AsyncActorRolloutRefWorker):
- def __init__(self, config: DictConfig, role: str):
- BaseDetachNcclSync.__init__(self, config, role)
- AsyncActorRolloutRefWorker.__init__(self, config, role)
-
- def _get_actor_params(self):
- pass
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights(self, sync_group_name="actor_rollout"):
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- if self._is_actor and self._is_offload_param:
- load_fsdp_model_to_gpu(self.actor_module_fsdp)
- params = self._get_actor_params() if self._is_actor else None
- rollout_name = self.config.rollout.name
-
- inference_model = None
- if self._is_rollout:
- if rollout_name == "vllm":
- inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
-
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- inference_model = self.rollout._engine
- # For ServerAdapter, _engine might be None and needs async initialization
- if inference_model is None:
- # Initialize the server adapter engine
- print("[sync_rollout_weights] Initialize server adapter engine")
-
- async def init_engine():
- if hasattr(self.rollout, "_init_server_adapter"):
- await self.rollout._init_server_adapter()
- else:
- print("[sync_rollout_weights] No _init_server_adapter method found")
- return self.rollout._engine
-
- inference_model = self._run_async_safely(init_engine())
- # For ServerAdapter, only TP rank 0 initializes the engine
- # TP rank != 0 can safely have inference_model as None
- from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter
-
- is_server_adapter = isinstance(self.rollout, ServerAdapter)
- is_non_tp_rank = False
- if (
- is_server_adapter
- and hasattr(self.rollout, "device_mesh")
- and self.rollout.device_mesh is not None
- ):
- try:
- is_non_tp_rank = self.rollout.device_mesh["infer_tp"].get_local_rank() != 0
- except Exception:
- pass
-
- if inference_model is None and not (is_server_adapter and is_non_tp_rank):
- raise RuntimeError(
- f"Failed to initialize rollout engine. "
- f"rollout type: {type(self.rollout)}, "
- f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}"
- )
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
-
- if rollout_name == "sglang" and self._is_rollout:
- self._sync_sglang_weights(inference_model, params, sync_group_name)
- else:
- self._sync_vllm_weights(inference_model, params, sync_group_name)
-
- if self._is_actor and self._is_offload_param:
- offload_fsdp_model_to_cpu(self.actor_module_fsdp)
- get_torch_device().empty_cache()
-
- def cache_actor_weights_to_cpu(self):
- self.cpu_named_params = {}
- if self._is_actor:
- params = self._get_actor_params()
- local_rank = torch.distributed.get_rank()
- world_size = torch.distributed.get_world_size()
-
- for tensor_idx, (key, _, _) in enumerate(self._weights_info):
- origin_data = params[key]
- if hasattr(origin_data, "full_tensor"):
- origin_data = origin_data.full_tensor()
-
- if tensor_idx % world_size == local_rank:
- self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
- get_torch_device().synchronize()
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- # Load model to GPU
- load_start_time = time.time()
- if self._is_actor and self._is_offload_param:
- load_fsdp_model_to_gpu(self.actor_module_fsdp)
- load_duration = time.time() - load_start_time
-
- from ray.util.collective import collective
-
- # Cache actor weights to CPU and measure the time taken
- cache_start_time = time.time()
- self.cache_actor_weights_to_cpu()
- cache_end_time = time.time()
- cache_duration = cache_end_time - cache_start_time
-
- # Register the cached weights into the checkpoint engine
- self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
- register_end_time = time.time()
- register_duration = register_end_time - cache_end_time
- self.cpu_named_params = {}
-
- collective.barrier(group_name=sync_group_name)
- update_start_time = time.time()
-
- inference_model = None
- if self._is_rollout:
- inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- patch_vllm_moe_model_weight_loader(inference_model)
-
- # Update the checkpoint with the inference model and broadcast weights
- self.checkpoint_engine.update_checkpoint(
- inference_model=inference_model,
- group_name=sync_group_name,
- overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume,
- )
-
- update_end_time = time.time()
- update_duration = update_end_time - update_start_time
-
- offload_start_time = time.time()
- if self._is_actor and self._is_offload_param:
- offload_fsdp_model_to_cpu(self.actor_module_fsdp)
- offload_duration = time.time() - offload_start_time
-
- print(
- f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()},"
- f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout},"
- f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, "
- f" register cost {register_duration} seconds, update cost {update_duration} seconds"
- )
-
- if self._is_actor and self._is_offload_param:
- print(
- f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
- f" offload model to cpu cost {offload_duration} seconds"
- )
-
-
-class DetachActorWorker(DetachNcclSync):
- def __init__(self, config: DictConfig, role: str):
- print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...")
- DetachNcclSync.__init__(self, config, role)
-
- def _get_actor_params(self):
- assert self._is_actor
- params = self.actor_module_fsdp.state_dict()
- from verl.utils.model import convert_weight_keys
-
- params = convert_weight_keys(
- params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
- )
- return params
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def get_actor_weights_info(self):
- assert self._is_actor
- if hasattr(self, "_weights_info"):
- return self._weights_info
- if fsdp_version(self.actor_module_fsdp) == 1:
- from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
-
- FSDP.set_state_dict_type(
- self.actor_module_fsdp,
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
- state_dict_config=ShardedStateDictConfig(),
- )
- params = self._get_actor_params()
- ret = []
- for key, tensor in params.items():
- ret.append((key, tensor.size(), tensor.dtype))
- self._weights_info = ret
- return ret
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def save_model_to_cpu(self, n):
- if not hasattr(self, "cpu_saved_models"):
- self.cpu_saved_models = {}
- self.cpu_saved_models[n] = fsdp2_sharded_save_to_cpu(self.actor_module_fsdp)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def restore_model_from_cpu(self, n):
- if n in self.cpu_saved_models:
- cpu_sharded_state, global_spec = self.cpu_saved_models[n]
- fsdp2_sharded_load_from_cpu(self.actor_module_fsdp, cpu_sharded_state, global_spec)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def clear_cpu_model(self, n):
- if n in self.cpu_saved_models:
- del self.cpu_saved_models[n]
-
-
-class DetachAsyncRolloutWorker(DetachNcclSync):
- def __init__(self, config: DictConfig, role: str):
- print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
- DetachNcclSync.__init__(self, config, role)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def set_actor_weights_info(self, weights_info):
- assert self._is_rollout
- self._weights_info = weights_info
diff --git a/verl/experimental/fully_async_policy/fully_async_main.py b/verl/experimental/fully_async_policy/fully_async_main.py
index ceeba563b18..4e9e509475f 100644
--- a/verl/experimental/fully_async_policy/fully_async_main.py
+++ b/verl/experimental/fully_async_policy/fully_async_main.py
@@ -82,45 +82,19 @@ def create_role_worker_mapping(config):
# Select worker class based on strategy
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
if use_legacy_worker_impl == "disable":
- from verl.experimental.separation.engine_workers import (
- DetachActorWorker,
- DetachAsyncRolloutWorker,
- TrainingWorker,
- )
+ from verl.experimental.separation.engine_workers import DetachActorWorker
from verl.single_controller.ray import RayWorkerGroup
+ from verl.workers.engine_workers import TrainingWorker
ray_worker_group_cls = RayWorkerGroup
CriticWorker = TrainingWorker
else:
- if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.experimental.fully_async_policy.fsdp_workers import (
- CriticWorker,
- DetachActorWorker,
- DetachAsyncRolloutWorker,
- )
- from verl.single_controller.ray import RayWorkerGroup
-
- ray_worker_group_cls = RayWorkerGroup
-
- elif config.actor_rollout_ref.actor.strategy == "megatron":
- assert config.critic.strategy == "megatron"
- from verl.experimental.fully_async_policy.megatron_worker import (
- CriticWorker,
- DetachActorWorker,
- DetachAsyncRolloutWorker,
- )
- from verl.single_controller.ray import RayWorkerGroup
-
- ray_worker_group_cls = RayWorkerGroup
- else:
- raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}")
+ raise NotImplementedError("Fully async policy does not support legacy worker implementation")
train_role = Role.ActorRollout if config.async_training.use_trainer_do_validate else Role.Actor
role_worker_mapping = {
train_role: ray.remote(DetachActorWorker),
- Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
Role.Critic: ray.remote(CriticWorker),
}
@@ -177,14 +151,14 @@ def _initialize_components(self, config) -> None:
print("[ASYNC MAIN] Creating FullyAsyncRollouter and FullyAsyncTrainer in parallel...")
with ThreadPoolExecutor(max_workers=2) as executor:
- rollouter_future = executor.submit(self._create_rollouter, config)
- rollouter_future.result()
-
# TODO: keep _create_rollouter and _create_trainer parallel
+ # Rollouter does not permit continuous allocation, so we allocate trainer first.
trainer_future = executor.submit(self._create_trainer, config)
- # Wait for both to complete
trainer_future.result()
+ rollouter_future = executor.submit(self._create_rollouter, config)
+ rollouter_future.result()
+
# sync total_train_steps between rollouter and trainer
total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote())
print(f"total_train_steps {total_train_steps}")
@@ -233,7 +207,7 @@ def _create_rollouter(self, config) -> None:
rollouter = FullyAsyncRollouter.remote(
config=config,
tokenizer=self.components["tokenizer"],
- role_worker_mapping={Role.Rollout: self.components["role_worker_mapping"][Role.Rollout]},
+ role_worker_mapping=None,
resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]),
ray_worker_group_cls=self.components["ray_worker_group_cls"],
processor=self.components["processor"],
@@ -313,6 +287,9 @@ def main(config):
from time import time
start_time = time()
+ # TODO: unify rollout config with actor_rollout_ref
+ config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes
+ config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node
run_ppo(config, task_runner_class=FullyAsyncTaskRunner)
print(f"total time: {time() - start_time:.2f} seconds")
diff --git a/verl/experimental/fully_async_policy/fully_async_rollouter.py b/verl/experimental/fully_async_policy/fully_async_rollouter.py
index 023b23a5c56..ce21be271cf 100644
--- a/verl/experimental/fully_async_policy/fully_async_rollouter.py
+++ b/verl/experimental/fully_async_policy/fully_async_rollouter.py
@@ -32,7 +32,7 @@
)
from verl.experimental.fully_async_policy.message_queue import MessageQueueClient
from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer
-from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
+from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
@@ -98,8 +98,20 @@ def __init__(
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.utils.dataset.rl_dataset import collate_fn
- train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
- val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+ train_dataset = create_rl_dataset(
+ config.data.train_files,
+ config.data,
+ tokenizer,
+ processor,
+ max_samples=config.data.get("train_max_samples", -1),
+ )
+ val_dataset = create_rl_dataset(
+ config.data.val_files,
+ config.data,
+ tokenizer,
+ processor,
+ max_samples=config.data.get("val_max_samples", -1),
+ )
train_sampler = create_rl_sampler(config.data, train_dataset)
self._validate_config()
@@ -217,6 +229,10 @@ def get_rollout_wg(self):
"""Get rollout worker group"""
return self.rollout_wg
+ def get_replicas(self):
+ """Get rollout worker group"""
+ return self.async_rollout_manager.rollout_replicas
+
def get_max_queue_size(self):
return self.max_queue_size
@@ -402,23 +418,13 @@ async def init_workers(self):
2. Worker groups for each role (actor, critic, etc.)
"""
self._init_async_objects()
- self._init_resource_pools()
self._create_worker_classes()
- self._init_worker_groups()
- self._init_models()
self._init_reward_loop()
await self._init_async_rollout_manager()
def _create_actor_rollout_classes(self):
- # only create rollout
- for role in [Role.Rollout]:
- resource_pool = self.resource_pool_manager.get_resource_pool(role)
- role_cls = RayClassWithInitArgs(
- cls=self.role_worker_mapping[role],
- config=self.config.actor_rollout_ref,
- role=str(role),
- )
- self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
+ # Skip rollout creation and let agentloop handle it
+ pass
def _init_models(self):
self.rollout_wg = self.all_wg[str(Role.Rollout)]
@@ -755,12 +761,10 @@ async def pause(self):
await asyncio.gather(*self.active_tasks, return_exceptions=True)
self.active_tasks.clear()
print("[FullyAsyncRollouter][Public][Pause] All active tasks completed")
-
- # TODO use checkpoint engine for rollout clear_kv_cache
- # print("[FullyAsyncRollouter][Public][Pause] clear kv cache")
- # # Always clear KV cache to release GPU memory during weight synchronization,
- # # regardless of partial_rollout setting.
- # await self.async_rollout_manager.clear_kv_cache()
+ print("[FullyAsyncRollouter][Public][Pause] Prefix cache reset")
+ # Always clear KV cache to release GPU memory during weight synchronization,
+ # regardless of partial_rollout setting.
+ await self.async_rollout_manager.clear_kv_cache()
self.monitor_loop_trigger = False
async def resume(self, dependency_ref: ObjectRef = None):
diff --git a/verl/experimental/fully_async_policy/fully_async_trainer.py b/verl/experimental/fully_async_policy/fully_async_trainer.py
index e0019be46a5..9519c594dbd 100644
--- a/verl/experimental/fully_async_policy/fully_async_trainer.py
+++ b/verl/experimental/fully_async_policy/fully_async_trainer.py
@@ -422,6 +422,7 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
if batch is None:
raise TrainingStopException("Training terminated: queue returned None")
self._collect_metrics_from_samples(batch, metrics)
+ batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
return batch
def _compute_old_log_prob(self, batch: DataProto):
@@ -561,8 +562,6 @@ def _save_checkpoint(self):
def load_checkpoint(self):
if self.config.trainer.resume_mode == "disable":
- # NOTE: while there is no checkpoint to load, we still need to offload the model and optimizer to CPU
- self.actor_rollout_wg.load_checkpoint(None)
return 0
# load from hdfs
@@ -578,8 +577,6 @@ def load_checkpoint(self):
# find global_step_folder
if self.config.trainer.resume_mode == "auto":
if global_step_folder is None:
- print("[FullyAsyncTrainer] Training from scratch")
- self.actor_rollout_wg.load_checkpoint(None)
return 0
else:
if self.config.trainer.resume_mode == "resume_path":
diff --git a/verl/experimental/fully_async_policy/megatron_worker.py b/verl/experimental/fully_async_policy/megatron_worker.py
deleted file mode 100644
index e0a1c2a437c..00000000000
--- a/verl/experimental/fully_async_policy/megatron_worker.py
+++ /dev/null
@@ -1,302 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-# Copyright 2025 Meituan Ltd. and/or its affiliates
-# Copyright 2025 NVIDIA Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import time
-
-import torch
-import torch.distributed
-from omegaconf import DictConfig
-
-from verl.experimental.fully_async_policy.base_detach_sync import BaseDetachNcclSync
-from verl.experimental.fully_async_policy.megatron_utils import (
- copy_megatron_model_to_cpu,
- restore_megatron_model_from_cpu,
-)
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import (
- get_device_name,
- get_torch_device,
-)
-from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator
-from verl.workers.megatron_workers import AsyncActorRolloutRefWorker, CriticWorker
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-device_name = get_device_name()
-
-__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"]
-
-
-class DetachNcclSync(BaseDetachNcclSync, AsyncActorRolloutRefWorker):
- def __init__(self, config: DictConfig, role: str):
- BaseDetachNcclSync.__init__(self, config, role)
-
- AsyncActorRolloutRefWorker.__init__(self, config, role)
-
- def _get_actor_params(self):
- pass
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights(self, sync_group_name="actor_rollout"):
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
- if self._is_actor and self._is_offload_param:
- load_megatron_model_to_gpu(self.actor_module, False)
- params_generator = self._get_actor_params_generator() if self._is_actor else None
- params = {key: tensor for key, tensor in params_generator} if params_generator is not None else None
-
- rollout_name = self.config.rollout.name
- inference_model = None
- if self._is_rollout and (not self._is_actor):
- if rollout_name == "vllm":
- inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- inference_model = self.rollout._engine
- if inference_model is None:
- print("[sync_rollout_weights] Initialize server adapter engine")
-
- async def init_engine():
- if hasattr(self.rollout, "_init_server_adapter"):
- await self.rollout._init_server_adapter()
- else:
- print("[sync_rollout_weights] No _init_server_adapter method found")
- return self.rollout._engine
-
- inference_model = self._run_async_safely(init_engine())
- # For ServerAdapter, only TP rank 0 initializes the engine
- # TP rank != 0 can safely have inference_model as None
- from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter
-
- is_server_adapter = isinstance(self.rollout, ServerAdapter)
- is_non_tp_rank = False
- if (
- is_server_adapter
- and hasattr(self.rollout, "device_mesh")
- and self.rollout.device_mesh is not None
- ):
- try:
- is_non_tp_rank = self.rollout.device_mesh["infer_tp"].get_local_rank() != 0
- except Exception:
- pass
-
- if inference_model is None and not (is_server_adapter and is_non_tp_rank):
- raise RuntimeError(
- f"Failed to initialize rollout engine. "
- f"rollout type: {type(self.rollout)}, "
- f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}"
- )
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
-
- if rollout_name == "sglang" and self._is_rollout:
- self._sync_sglang_weights(inference_model, params, sync_group_name)
- else:
- self._sync_vllm_weights(inference_model, params, sync_group_name)
-
- if self._is_actor and self._is_offload_param:
- offload_megatron_model_to_cpu(self.actor_module)
- get_torch_device().empty_cache()
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def save_model_to_cpu(self, n):
- if not hasattr(self, "cpu_saved_models"):
- self.cpu_saved_models = {}
- self.cpu_saved_models[n] = copy_megatron_model_to_cpu(self.actor.actor_module)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def restore_model_from_cpu(self, n):
- if n in self.cpu_saved_models:
- restore_megatron_model_from_cpu(self.actor.actor_module, self.cpu_saved_models[n])
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def clear_cpu_model(self, n):
- if n in self.cpu_saved_models:
- del self.cpu_saved_models[n]
-
- def cache_actor_weights_to_cpu(self):
- self.cpu_named_params = {}
- if self._is_actor:
- params_generator = self._get_actor_params_generator()
- local_rank = torch.distributed.get_rank()
- world_size = torch.distributed.get_world_size()
- print(f"cache_actor_weights_to_cpu, local_rank:{local_rank}, world_size:{world_size}")
- for tensor_idx, (key, tensor) in enumerate(params_generator):
- if tensor_idx % world_size == local_rank:
- self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True)
- get_torch_device().synchronize()
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- # Load model to GPU
- load_start_time = time.time()
- if self._is_actor and self._is_offload_param:
- load_megatron_model_to_gpu(self.actor_module, False)
- load_duration = time.time() - load_start_time
-
- from ray.util.collective import collective
-
- # Cache actor weights to CPU and measure the time taken
- cache_start_time = time.time()
- self.cache_actor_weights_to_cpu()
- cache_end_time = time.time()
- cache_duration = cache_end_time - cache_start_time
-
- # Register the cached weights into the checkpoint engine
- self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
- register_end_time = time.time()
- register_duration = register_end_time - cache_end_time
- self.cpu_named_params = {}
-
- collective.barrier(group_name=sync_group_name)
- update_start_time = time.time()
-
- rollout_name = self.config.rollout.name
- inference_model = None
- if self._is_rollout and (not self._is_actor):
- if rollout_name == "vllm":
- inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- inference_model = self.rollout._engine
- # For ServerAdapter, _engine might be None and needs async initialization
- if inference_model is None:
- # Initialize the server adapter engine
- print("[sync_rollout_weights] Initialize server adapter engine")
-
- async def init_engine():
- if hasattr(self.rollout, "_init_server_adapter"):
- await self.rollout._init_server_adapter()
- else:
- print("[sync_rollout_weights] No _init_server_adapter method found")
- return self.rollout._engine
-
- inference_model = self._run_async_safely(init_engine())
- # For ServerAdapter, only TP rank 0 initializes the engine
- # TP rank != 0 can safely have inference_model as None
- from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter
-
- is_server_adapter = isinstance(self.rollout, ServerAdapter)
- is_non_tp_rank = False
- if (
- is_server_adapter
- and hasattr(self.rollout, "device_mesh")
- and self.rollout.device_mesh is not None
- ):
- try:
- is_non_tp_rank = self.rollout.device_mesh["infer_tp"].get_local_rank() != 0
- except Exception:
- pass
-
- if inference_model is None and not (is_server_adapter and is_non_tp_rank):
- raise RuntimeError(
- f"Failed to initialize rollout engine. "
- f"rollout type: {type(self.rollout)}, "
- f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}"
- )
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
- # Update the checkpoint with the inference model and broadcast weights
- self.checkpoint_engine.update_checkpoint(
- inference_model=inference_model,
- group_name=sync_group_name,
- overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume,
- )
-
- update_end_time = time.time()
- update_duration = update_end_time - update_start_time
-
- collective.barrier(group_name=sync_group_name)
- offload_start_time = time.time()
- if self._is_actor and self._is_offload_param:
- offload_megatron_model_to_cpu(self.actor_module)
- offload_duration = time.time() - offload_start_time
-
- print(
- f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()},"
- f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout},"
- f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, "
- f" register cost {register_duration} seconds, update cost {update_duration} seconds"
- )
-
- if self._is_actor and self._is_offload_param:
- print(
- f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
- f" offload model to cpu cost {offload_duration} seconds"
- )
-
-
-class DetachActorWorker(DetachNcclSync):
- def __init__(self, config: DictConfig, role: str):
- print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...")
- DetachNcclSync.__init__(self, config, role)
-
- def _get_actor_params_generator(self):
- assert self._is_actor
- if self.bridge is not None:
- if self.vanilla_bridge:
- generator = self.bridge.export_weights(self.actor.actor_module)
- else:
- generator = self.bridge.export_hf_weights(self.actor.actor_module)
- else:
- generator = per_tensor_generator(
- self.actor.actor_module,
- self.actor_model_config,
- self.weight_converter,
- self.tf_config,
- self.layer_name_mapping,
- )
-
- return generator
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def get_actor_weights_info(self):
- assert self._is_actor
- if hasattr(self, "_weights_info"):
- return self._weights_info
- if self._is_offload_param:
- load_megatron_model_to_gpu(self.actor_module, False)
- params_generator = self._get_actor_params_generator()
- ret = []
- for key, tensor in params_generator:
- ret.append((key, tensor.size(), tensor.dtype))
-
- self._weights_info = ret
- # Here, we only call this function at the beginning,
- # and immediately afterwards we call sync_rollout_weights.
- # So we no longer call offload in this.
- return ret
-
-
-class DetachAsyncRolloutWorker(DetachNcclSync):
- def __init__(self, config: DictConfig, role: str):
- print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
- DetachNcclSync.__init__(self, config, role)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def set_actor_weights_info(self, weights_info):
- assert self._is_rollout
- self._weights_info = weights_info
diff --git a/verl/experimental/fully_async_policy/param_sync.py b/verl/experimental/fully_async_policy/param_sync.py
index 000568d12d6..88982efef15 100644
--- a/verl/experimental/fully_async_policy/param_sync.py
+++ b/verl/experimental/fully_async_policy/param_sync.py
@@ -18,6 +18,8 @@
import ray
from ray.util.collective import collective
+from verl.checkpoint_engine import CheckpointEngineManager
+from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_nccl_backend
logger = logging.getLogger(__name__)
@@ -50,11 +52,11 @@ def __init__(self, config, trainer, rollouter, mq):
# Statistics
self.current_version = 0
- self._init_weights_info()
- self._init_sync_group()
-
- if self.config.async_training.checkpoint_engine.enable:
- self._init_actor_rollout_checkpoint_engine()
+ replicas = ray.get(rollouter.get_replicas.remote())
+ checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine)
+ self.checkpoint_manager = CheckpointEngineManager(
+ config=checkpoint_engine_config, trainer=self.actor_wg, replicas=replicas
+ )
def get_current_param_version(self) -> int:
"""Get current parameter version number"""
@@ -98,22 +100,6 @@ def _init_sync_group(self):
group_name=self.sync_group_name,
)
- def _init_actor_rollout_checkpoint_engine(self):
- ray.get(
- self.actor_wg.init_checkpoint_engine(
- rank_offset=0,
- actor_num=len(self.actor_wg.workers),
- rollout_num=len(self.rollout_wg.workers),
- )
- )
- ray.get(
- self.rollout_wg.init_checkpoint_engine(
- rank_offset=len(self.actor_wg.workers),
- actor_num=len(self.actor_wg.workers),
- rollout_num=len(self.rollout_wg.workers),
- )
- )
-
def sync_weights(self, version, validate=False, global_steps=0, use_trainer_do_validate=False):
"""Sync weights between trainer and rollouter, and update parameter version"""
start_time = time.time()
@@ -130,16 +116,7 @@ def sync_weights(self, version, validate=False, global_steps=0, use_trainer_do_v
# sync weights
# For sglang, always use sync_rollout_weights instead of sync_rollout_weights_by_checkpoint
- # TODO use checkpoint engine for sglang rollout
- # rollout_name = getattr(self.config.actor_rollout_ref.rollout, "name", None)
- # use_checkpoint_engine = self.config.async_training.checkpoint_engine.enable and rollout_name != "sglang"
- # if use_checkpoint_engine:
- # self.actor_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name)
- # ray.get(self.rollout_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name))
- # else:
- # self.actor_wg.sync_rollout_weights(self.sync_group_name)
- # ray.get(self.rollout_wg.sync_rollout_weights(self.sync_group_name))
-
+ self.checkpoint_manager.update_weights()
end_time = time.time()
print(
f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds, "
diff --git a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
index e50abb4dd5d..cc936f50dc1 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
@@ -103,7 +103,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
- actor_rollout_ref.actor.strategy=fsdp \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp \
critic.strategy=fsdp \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh
index cc1a9bb65c0..2a5eb1bb966 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh
@@ -98,7 +98,7 @@ python3 -m verl.experimental.fully_async_policy.fully_async_main \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh
index b076456bc27..ba8e6804fdb 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh
@@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh
index 203e4c03bd9..5561208ee6d 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh
@@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh
index 7f29d44e225..242a5117a5e 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh
@@ -92,7 +92,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh
index cd2906a0053..ee0657eace7 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh
@@ -92,7 +92,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh
index 81b50382c32..002c1206b8a 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh
@@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh
index 7cc55f632aa..f01fb8184e7 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh
@@ -96,7 +96,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh
index ddb177681be..2b2143ffa21 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh
@@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh
index f00b41085c5..c04a09d3266 100644
--- a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh
+++ b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh
@@ -131,7 +131,6 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- async_training.compute_prox_log_prob=True \
algorithm.rollout_correction.rollout_is=${rollout_is} \
algorithm.rollout_correction.rollout_rs=${rollout_rs} \
algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \
diff --git a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
index cb2f8c2054c..0e4677be368 100644
--- a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
+++ b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
@@ -6,6 +6,9 @@ defaults:
- ppo_megatron_trainer
- _self_
+trainer:
+ use_legacy_worker_impl: disable
+
# config for the rollout (only for resource isolation)
rollout:
# Number of nodes used in the rollout
@@ -20,9 +23,9 @@ actor_rollout_ref:
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
- # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
- # TODO: Can be removed in the future once parameter synchronization is ready.
- load_format: "auto"
+
+ checkpoint_engine:
+ backend: "nccl"
# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
diff --git a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
index 012745e2aa3..dc784b2ae73 100644
--- a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
+++ b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
@@ -6,6 +6,9 @@ defaults:
- ppo_trainer
- _self_
+trainer:
+ use_legacy_worker_impl: disable
+
# config for the rollout (only for resource isolation)
rollout:
# Number of nodes used in the rollout
@@ -20,9 +23,9 @@ actor_rollout_ref:
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
- # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
- # TODO: Can be removed in the future once parameter synchronization is ready.
- load_format: "auto"
+
+ checkpoint_engine:
+ backend: "nccl"
# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
diff --git a/verl/experimental/one_step_off_policy/fsdp_workers.py b/verl/experimental/one_step_off_policy/fsdp_workers.py
deleted file mode 100644
index 67584a198bc..00000000000
--- a/verl/experimental/one_step_off_policy/fsdp_workers.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-# Copyright 2025 Meituan Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-
-import torch
-import torch.distributed
-from omegaconf import DictConfig
-from ray.util.collective import collective
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-from verl.experimental.one_step_off_policy.distributed_utils import vllm_stateless_init_process_group
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import (
- get_device_name,
- get_torch_device,
-)
-from verl.utils.fsdp_utils import (
- fsdp_version,
- load_fsdp_model_to_gpu,
- offload_fsdp_model_to_cpu,
-)
-from verl.utils.ray_utils import get_event_loop
-from verl.workers.fsdp_workers import (
- ActorRolloutRefWorker,
- AsyncActorRolloutRefWorker,
- CriticWorker,
-)
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-device_name = get_device_name()
-
-__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"]
-
-
-class DetachSync(AsyncActorRolloutRefWorker):
- def _get_actor_params(self):
- pass
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):
- rank = torch.distributed.get_rank() + rank_offset
- self._weight_sync_group = vllm_stateless_init_process_group(
- master_address,
- master_port,
- rank,
- world_size,
- get_torch_device().current_device(),
- )
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights(self):
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- if self._is_actor and self._is_offload_param:
- load_fsdp_model_to_gpu(self.actor_module_fsdp)
- params = self._get_actor_params() if self._is_actor else None
-
- rollout_name = self.config.rollout.name
- if self._is_rollout:
- if rollout_name == "vllm":
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- inference_model = self.rollout.inference_engine.worker.model_runner.model
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- inference_model = self.rollout._engine
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
- loop = get_event_loop()
- for key, shape, dtype in self._weights_info:
- tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
- if self._is_actor:
- assert key in params
- origin_data = params[key]
- if hasattr(origin_data, "full_tensor"):
- origin_data = origin_data.full_tensor()
- if torch.distributed.get_rank() == 0:
- tensor.copy_(origin_data)
-
- if device_name == "npu":
- self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
- else:
- collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
-
- if self._is_rollout:
- if rollout_name == "vllm":
- inference_model.load_weights([(key, tensor)])
- elif rollout_name == "sglang":
- # first_rank_in_node = self._tp_rank % tp_size_per_node == 0,
- # Only the first rank within each node (i.e., the local rank is 0) initializes the engine;
- # engines for other ranks are set to None.
-
- if inference_model is not None:
- loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))
-
- if self._is_actor and self._is_offload_param:
- offload_fsdp_model_to_cpu(self.actor_module_fsdp)
- get_torch_device().empty_cache()
-
- async def update_weights(self, inference_engine, params):
- from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
-
- await sgl_update_weights(
- engine=inference_engine,
- params_batch=params,
- device_mesh_key="infer_tp",
- device_mesh=self.rollout_device_mesh,
- )
-
- if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
- await inference_engine.flush_cache()
-
-
-class DetachActorWorker(DetachSync):
- def _get_actor_params(self):
- assert self._is_actor
- params = self.actor_module_fsdp.state_dict()
- from verl.utils.model import convert_weight_keys
-
- params = convert_weight_keys(
- params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
- )
- return params
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def get_actor_weights_info(self):
- assert self._is_actor
- if hasattr(self, "_weights_info"):
- return self._weights_info
- if fsdp_version(self.actor_module_fsdp) == 1:
- from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
-
- FSDP.set_state_dict_type(
- self.actor_module_fsdp,
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
- state_dict_config=ShardedStateDictConfig(),
- )
- params = self._get_actor_params()
- ret = []
- for key, tensor in params.items():
- ret.append((key, tensor.size(), tensor.dtype))
- self._weights_info = ret
- return ret
-
-
-class DetachAsyncRolloutWorker(DetachSync):
- def __init__(self, config: DictConfig, role: str):
- print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
- ActorRolloutRefWorker.__init__(self, config, role)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def set_actor_weights_info(self, weights_info):
- assert self._is_rollout
- self._weights_info = weights_info
diff --git a/verl/experimental/one_step_off_policy/main_ppo.py b/verl/experimental/one_step_off_policy/main_ppo.py
index 6a7405d6c0f..0c6ecaedf0e 100644
--- a/verl/experimental/one_step_off_policy/main_ppo.py
+++ b/verl/experimental/one_step_off_policy/main_ppo.py
@@ -64,10 +64,6 @@ def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"
- rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
- resource_pool_spec["rollout_pool"] = rollout_pool
- mapping[Role.Rollout] = "rollout_pool"
-
return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
@@ -81,48 +77,15 @@ def create_role_worker_mapping(config):
Returns:
dict: Mapping from roles to worker classes
"""
- # Select worker class based on strategy
- use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
- if use_legacy_worker_impl == "disable":
- from verl.experimental.separation.engine_workers import (
- DetachActorWorker,
- DetachAsyncRolloutWorker,
- TrainingWorker,
- )
- from verl.single_controller.ray import RayWorkerGroup
-
- ray_worker_group_cls = RayWorkerGroup
-
- CriticWorker = TrainingWorker
- else:
- if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.experimental.one_step_off_policy.fsdp_workers import (
- CriticWorker,
- DetachActorWorker,
- DetachAsyncRolloutWorker,
- )
- from verl.single_controller.ray import RayWorkerGroup
-
- ray_worker_group_cls = RayWorkerGroup
-
- elif config.actor_rollout_ref.actor.strategy == "megatron":
- assert config.critic.strategy == "megatron"
- from verl.experimental.one_step_off_policy.megatron_workers import (
- CriticWorker,
- DetachActorWorker,
- DetachAsyncRolloutWorker,
- )
- from verl.single_controller.ray import RayWorkerGroup
-
- ray_worker_group_cls = RayWorkerGroup
- else:
- raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}")
+ from verl.experimental.separation.engine_workers import DetachActorWorker
+ from verl.single_controller.ray import RayWorkerGroup
+ from verl.workers.engine_workers import TrainingWorker
+
+ ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.Actor: ray.remote(DetachActorWorker),
- Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
- Role.Critic: ray.remote(CriticWorker),
+ Role.Critic: ray.remote(TrainingWorker),
}
# Add reference policy (if KL loss or reward is required)
@@ -219,6 +182,10 @@ def main(config):
# Automatically set `config.trainer.device = npu` when running on Ascend NPU.
auto_set_device(config)
+ # TODO: unify rollout config with actor_rollout_ref
+ config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes
+ config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node
+
run_ppo(config, task_runner_class=OneStepTaskRunner)
print(f"total time: {time() - start_time:.2f} seconds")
diff --git a/verl/experimental/one_step_off_policy/megatron_workers.py b/verl/experimental/one_step_off_policy/megatron_workers.py
deleted file mode 100644
index d5c0a18cfea..00000000000
--- a/verl/experimental/one_step_off_policy/megatron_workers.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-# Copyright 2025 Meituan Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-
-import torch
-import torch.distributed
-from omegaconf import DictConfig
-from ray.util.collective import collective
-
-from verl.experimental.one_step_off_policy.distributed_utils import vllm_stateless_init_process_group
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import (
- get_device_name,
- get_torch_device,
-)
-from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu
-from verl.utils.ray_utils import get_event_loop
-from verl.workers.megatron_workers import (
- ActorRolloutRefWorker,
- AsyncActorRolloutRefWorker,
- CriticWorker,
-)
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-device_name = get_device_name()
-
-__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"]
-
-
-class DetachSync(AsyncActorRolloutRefWorker):
- def _get_actor_params(self):
- pass
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):
- rank = torch.distributed.get_rank() + rank_offset
- self._weight_sync_group = vllm_stateless_init_process_group(
- master_address,
- master_port,
- rank,
- world_size,
- get_torch_device().current_device(),
- )
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights(self):
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- params_generator = self._get_actor_params_generator() if self._is_actor else None
-
- if self._is_actor and self._is_offload_param:
- load_megatron_model_to_gpu(self.actor_module)
-
- rollout_name = self.config.rollout.name
- if self._is_rollout:
- if rollout_name == "vllm":
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- inference_model = self.rollout.inference_engine.worker.model_runner.model
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- inference_model = self.rollout._engine
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
-
- loop = get_event_loop()
- for key, shape, dtype in self._weights_info:
- if self._is_actor:
- weight_key, weight = next(params_generator)
- assert key == weight_key
- assert shape == weight.size()
- assert dtype == weight.dtype
-
- tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
- if self._is_actor and torch.distributed.get_rank() == 0:
- tensor.copy_(weight)
-
- if device_name == "npu":
- self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
- else:
- collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
-
- if self._is_rollout:
- if rollout_name == "vllm":
- inference_model.load_weights([(key, tensor)])
- elif rollout_name == "sglang":
- # first_rank_in_node = self._tp_rank % tp_size_per_node == 0,
- # Only the first rank within each node (i.e., the local rank is 0) initializes the engine;
- # engines for other ranks are set to None.
-
- if inference_model is not None:
- loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))
-
- if self._is_actor and self._is_offload_param:
- offload_megatron_model_to_cpu(self.actor_module)
-
- async def update_weights(self, inference_engine, params):
- from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
-
- await sgl_update_weights(
- engine=inference_engine,
- params_batch=params,
- device_mesh_key="infer_tp",
- device_mesh=self.rollout_device_mesh,
- )
-
- if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
- await inference_engine.flush_cache()
-
-
-class DetachActorWorker(DetachSync):
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def _get_actor_params_generator(self):
- assert self._is_actor
- from verl.models.mcore import get_mcore_weight_converter
- from verl.utils.megatron_utils import per_tensor_generator
-
- layer_name_mapping = {
- "qkv_layer_name": "self_attention.linear_qkv.",
- "gate_proj_layer_name": "linear_fc1.",
- }
- weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
- generator = per_tensor_generator(
- self.actor.actor_module,
- self.actor_model_config,
- weight_converter,
- self.tf_config,
- layer_name_mapping,
- )
- return generator
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def get_actor_weights_info(self):
- assert self._is_actor
- if hasattr(self, "_weights_info"):
- return self._weights_info
- if self._is_offload_param:
- load_megatron_model_to_gpu(self.actor_module)
- params_generator = self._get_actor_params_generator()
- ret = []
- for key, tensor in params_generator:
- ret.append((key, tensor.size(), tensor.dtype))
-
- self._weights_info = ret
- # Here, we only call this function at the beginning,
- # and immediately afterwards we call sync_rollout_weights.
- # So we no longer call offload in this.
- return ret
-
-
-class DetachAsyncRolloutWorker(DetachSync):
- def __init__(self, config: DictConfig, role: str):
- print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
- ActorRolloutRefWorker.__init__(self, config, role)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def set_actor_weights_info(self, weights_info):
- assert self._is_rollout
- self._weights_info = weights_info
diff --git a/verl/experimental/one_step_off_policy/ray_trainer.py b/verl/experimental/one_step_off_policy/ray_trainer.py
index a0bdc0150b3..144632dead5 100644
--- a/verl/experimental/one_step_off_policy/ray_trainer.py
+++ b/verl/experimental/one_step_off_policy/ray_trainer.py
@@ -27,7 +27,6 @@
import ray
import torch
from omegaconf import OmegaConf
-from ray.util.collective import collective
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
@@ -88,6 +87,8 @@ def __init__(
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert not self.hybrid_engine
+ # Skip rollout worker mapping and let agentloop create it.
+ role_worker_mapping.pop(Role.Rollout, None)
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = need_reference_policy(self.config)
@@ -138,14 +139,8 @@ def __init__(
self.reward_tensor = None
self.reward_extra_infos_dict = {}
- def _validate(self):
- self.actor_rollout_wg = self.rollout_wg
- ret = super()._validate()
- self.actor_rollout_wg = self.actor_wg
- return ret
-
def _create_actor_rollout_classes(self):
- for role in [Role.Actor, Role.Rollout]:
+ for role in [Role.Actor]:
resource_pool = self.resource_pool_manager.get_resource_pool(role)
role_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[role],
@@ -169,13 +164,8 @@ def _init_models(self):
self.rm_wg.init_model()
self.actor_wg = self.all_wg[str(Role.Actor)]
- self.rollout_wg = self.all_wg[str(Role.Rollout)]
self.actor_wg.init_model()
- self.rollout_wg.init_model()
self.actor_rollout_wg = self.actor_wg
- weights_info = self.actor_wg.get_actor_weights_info()[0]
- self.rollout_wg.set_actor_weights_info(weights_info)
- self._create_weight_sync_group()
def _init_async_rollout_manager(self):
# infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design
@@ -192,47 +182,10 @@ def _init_async_rollout_manager(self):
from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager
self.async_rollout_mode = True
- self.async_rollout_manager = OneStepOffAgentLoopManager(
- config=self.config, worker_group=self.rollout_wg, reward_loop_worker_handles=reward_loop_worker_handles
+ self.async_rollout_manager = OneStepOffAgentLoopManager.create(
+ config=self.config, reward_loop_worker_handles=reward_loop_worker_handles
)
- def _create_weight_sync_group(self):
- from verl.utils.device import get_nccl_backend
-
- actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
- n_workers = len(actor_rollout_workers)
-
- if self.device_name == "npu":
- master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote()).strip("[]")
- master_port = ray.get(self.actor_wg.workers[0]._get_free_port.remote())
- self.actor_wg.create_weight_sync_group(
- master_address,
- master_port,
- 0,
- n_workers,
- )
- ray.get(
- self.rollout_wg.create_weight_sync_group(
- master_address,
- master_port,
- len(self.actor_wg.workers),
- n_workers,
- )
- )
- else:
- # Create Ray collective group for fallback communication
- collective.create_collective_group(
- actor_rollout_workers,
- n_workers,
- list(range(0, n_workers)),
- backend=get_nccl_backend(),
- group_name="actor_rollout",
- )
-
- def sync_rollout_weights(self):
- self.actor_wg.sync_rollout_weights()
- ray.get(self.rollout_wg.sync_rollout_weights())
-
def _create_continuous_iterator(self):
"""
Create a continuous data iterator across epoch
@@ -434,6 +387,7 @@ async def _fit_generate(self, batch_data_future, continuous_iterator):
with marked_timer("gen", timing_raw, color="red"):
_metrics, _timing_raw, epoch, batch, future_reward = await batch_data_future
+ batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
timing_raw.update(batch.meta_info["timing"])
timing_raw.update(_timing_raw)
metrics.update(_metrics)
@@ -452,8 +406,3 @@ async def _fit_generate(self, batch_data_future, continuous_iterator):
batch_data_future = None
return batch, batch_data_future
-
- def _fit_update_weights(self):
- # TODO: use checkpoint engine to update weight
- # self.sync_rollout_weights()
- pass
diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh
index 7101009e07b..cbefe87424b 100644
--- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh
+++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh
@@ -73,7 +73,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh
index 2cd5ae46d9c..c35513cf9f2 100644
--- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh
+++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh
@@ -75,7 +75,7 @@ python -m verl.experimental.one_step_off_policy.main_ppo \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh
index 8775bdfa0f5..10ce9122269 100644
--- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh
+++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh
@@ -85,7 +85,7 @@ python -m verl.experimental.one_step_off_policy.main_ppo \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh
index b44fd6b25e6..a5c6ee87143 100644
--- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh
+++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh
@@ -68,7 +68,7 @@ python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh
index e56d2f90dd5..2725bb5bc3d 100644
--- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh
+++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh
@@ -73,7 +73,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh
index 3c18460f1e2..5ccceec5f9d 100644
--- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh
+++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh
@@ -68,7 +68,7 @@ python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
diff --git a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh
index b2dfa578ed7..facabdf58e8 100644
--- a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh
+++ b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh
@@ -26,7 +26,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
diff --git a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh
index 1f5f72e6bcc..5c959f49961 100644
--- a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh
+++ b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh
@@ -26,7 +26,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
diff --git a/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh b/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh
index b94a66f588b..c5c5eb11d2a 100644
--- a/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh
+++ b/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh
@@ -25,7 +25,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
- actor_rollout_ref.actor.strategy=fsdp2 \
+ actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
diff --git a/verl/experimental/reward_loop/reward_loop.py b/verl/experimental/reward_loop/reward_loop.py
index 9261db23719..712cd9bde7a 100644
--- a/verl/experimental/reward_loop/reward_loop.py
+++ b/verl/experimental/reward_loop/reward_loop.py
@@ -38,7 +38,6 @@
def migrate_legacy_reward_impl(config):
"""
Migrate the legacy reward model implementation to the new one.
- This is a temporary fix. A more robust one will be added.
"""
# 1. reward workers migration
# config.reward_model.num_workers -> config.reward.num_workers
@@ -49,7 +48,7 @@ def migrate_legacy_reward_impl(config):
# config.reward_model.reward_manager -> config.reward.reward_manager
if config.reward_model.reward_manager is not None:
config.reward.reward_manager.name = config.reward_model.reward_manager
- if config.reward_model.get("reward_loop_source") is not None:
+ if config.reward_model.reward_loop_source is not None:
config.reward.reward_manager.source = config.reward_model.reward_loop_source
config.reward.reward_manager.module.path = config.reward_model.reward_loop_module_path
config.reward.reward_manager.module.name = config.reward_model.reward_loop_class_name
@@ -64,19 +63,29 @@ def migrate_legacy_reward_impl(config):
for key in ["enable", "enable_resource_pool", "n_gpus_per_node", "nnodes"]:
if config.reward_model.get(key) is not None:
config.reward.reward_model[key] = config.reward_model[key]
- # for dapo reward kwargs
+ if config.reward_model.model.path is not None:
+ config.reward.reward_model.model_path = config.reward_model.model.path
+ # config.reward_model.reward_kwargs -> config.reward.reward_kwargs (for dapo algo)
if config.reward_model.get("reward_kwargs") is not None:
- with open_dict(config.reward.reward_model):
- config.reward.reward_model["reward_kwargs"] = config.reward_model["reward_kwargs"]
+ with open_dict(config.reward):
+ config.reward["reward_kwargs"] = config.reward_model["reward_kwargs"]
+ # config.reward_model.rollout -> config.reward.reward_model.rollout
legacy_rollout = config.reward_model.rollout
- if not all(v is None for v in legacy_rollout.values()):
- config.reward.reward_model.rollout = legacy_rollout
+ for key in legacy_rollout.keys():
+ if legacy_rollout[key] is not None:
+ config.reward.reward_model.rollout[key] = legacy_rollout[key]
# 5. sandbox_fusion migration
# config.sandbox_fusion -> reward.sandbox_fusion
if not all(v is None for v in config.sandbox_fusion.values()):
config.reward.sandbox_fusion = config.sandbox_fusion
+ # 6. delete legacy config from configs
+ with open_dict(config):
+ del config.reward_model
+ del config.custom_reward_function
+ del config.sandbox_fusion
+
return config
@@ -222,12 +231,10 @@ async def compute_score_disrm(self, data: DataProto) -> dict:
engine_name = self.config.reward.reward_model.rollout.name
model_name = self.config.reward.reward_model.model_path
if engine_name == "vllm":
- # TODO (dyy): the "activation" has been changed to "use_activation" in vllm 0.11.2
payloads = {
"model": model_name,
"input": disrm_prompt,
- "activation": False,
- # "add_special_tokens": False, # vllm >= 0.11.2
+ "use_activation": False,
}
output = await self._post_request(payloads, "classify")
rm_score = output["data"][-1]["probs"][-1]
diff --git a/verl/experimental/separation/engine_workers.py b/verl/experimental/separation/engine_workers.py
index 5e98052bb79..0f8062ff888 100644
--- a/verl/experimental/separation/engine_workers.py
+++ b/verl/experimental/separation/engine_workers.py
@@ -15,290 +15,58 @@
import logging
import os
-import time
-import torch
-import torch.distributed
from omegaconf import DictConfig
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from verl.experimental.fully_async_policy.base_detach_sync import BaseDetachNcclSync
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
- get_torch_device,
)
-from verl.utils.fsdp_utils import fsdp_version
-from verl.utils.megatron_utils import per_tensor_generator
-from verl.workers.engine_workers import ActorRolloutRefWorker, TrainingWorker
+from verl.workers.engine_workers import ActorRolloutRefWorker
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
device_name = get_device_name()
-__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "TrainingWorker"]
+__all__ = ["DetachActorWorker"]
-class DetachNcclSync(BaseDetachNcclSync, ActorRolloutRefWorker):
- def __init__(self, config: DictConfig, role: str):
- BaseDetachNcclSync.__init__(self, config, role)
- ActorRolloutRefWorker.__init__(self, config, role)
-
- def _get_actor_params(self):
- pass
-
- def load_model_to_gpu(self):
- if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- from verl.utils.fsdp_utils import load_fsdp_model_to_gpu
-
- load_fsdp_model_to_gpu(self.actor_module_fsdp)
- elif self.config.actor_rollout_ref.actor.strategy == "megatron":
- from verl.utils.megatron_utils import load_megatron_model_to_gpu
-
- load_megatron_model_to_gpu(self.actor_module, False)
- else:
- raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}")
-
- def offload_model_to_cpu(self):
- if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu
-
- offload_fsdp_model_to_cpu(self.actor_module_fsdp)
- elif self.config.actor_rollout_ref.actor.strategy == "megatron":
- from verl.utils.megatron_utils import offload_megatron_model_to_cpu
-
- offload_megatron_model_to_cpu(self.actor_module)
- else:
- raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}")
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights(self, sync_group_name="actor_rollout"):
- # TODO: Refator this function for the chekpoint engine
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- if self._is_actor and self.engine._is_offload_param:
- self.load_model_to_gpu()
-
- if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- params = self._get_actor_params() if self._is_actor else None
-
- elif self.config.actor_rollout_ref.actor.strategy == "megatron":
- params_generator = self._get_actor_params_generator() if self._is_actor else None
- params = {key: tensor for key, tensor in params_generator} if params_generator is not None else None
-
- else:
- raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}")
-
- rollout_name = self.config.rollout.name
-
- inference_model = None
- if self._is_rollout and (not self._is_actor):
- if rollout_name == "vllm":
- inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
-
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- inference_model = self.rollout._engine
- # For ServerAdapter, _engine might be None and needs async initialization
- if inference_model is None:
- # Initialize the server adapter engine
- print("[sync_rollout_weights] Initialize server adapter engine")
-
- async def init_engine():
- if hasattr(self.rollout, "_init_server_adapter"):
- await self.rollout._init_server_adapter()
- else:
- print("[sync_rollout_weights] No _init_server_adapter method found")
- return self.rollout._engine
-
- inference_model = self._run_async_safely(init_engine())
- if inference_model is None:
- raise RuntimeError(
- f"Failed to initialize rollout engine. "
- f"rollout type: {type(self.rollout)}, "
- f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}"
- )
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
-
- if rollout_name == "sglang" and self._is_rollout:
- self._sync_sglang_weights(inference_model, params, sync_group_name)
- else:
- self._sync_vllm_weights(inference_model, params, sync_group_name)
-
- if self._is_actor and self.engine._is_offload_param:
- self.offload_model_to_cpu()
- get_torch_device().empty_cache()
-
- def cache_actor_weights_to_cpu(self):
- # TODO: Refator this function for the chekpoint engine
- self.cpu_named_params = {}
- local_rank = torch.distributed.get_rank()
- world_size = torch.distributed.get_world_size()
- if self._is_actor:
- if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- params = self._get_actor_params()
- for tensor_idx, (key, _, _) in enumerate(self._weights_info):
- origin_data = params[key]
- if hasattr(origin_data, "full_tensor"):
- origin_data = origin_data.full_tensor()
-
- if tensor_idx % world_size == local_rank:
- self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
-
- elif self.config.actor_rollout_ref.actor.strategy == "megatron":
- params_generator = self._get_actor_params_generator()
- print(f"cache_actor_weights_to_cpu, local_rank:{local_rank}, world_size:{world_size}")
- for tensor_idx, (key, tensor) in enumerate(params_generator):
- if tensor_idx % world_size == local_rank:
- self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True)
- else:
- raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}")
-
- get_torch_device().synchronize()
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
- def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
- # TODO: Refator this function for the chekpoint engine
- assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
- assert hasattr(self, "_weights_info") and self._weights_info is not None
-
- # Load model to GPU
- load_start_time = time.time()
- if self._is_actor and self.engine._is_offload_param:
- self.load_model_to_gpu()
- load_duration = time.time() - load_start_time
-
- from ray.util.collective import collective
-
- # Cache actor weights to CPU and measure the time taken
- cache_start_time = time.time()
- self.cache_actor_weights_to_cpu()
- cache_end_time = time.time()
- cache_duration = cache_end_time - cache_start_time
-
- # Register the cached weights into the checkpoint engine
- self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
- register_end_time = time.time()
- register_duration = register_end_time - cache_end_time
- self.cpu_named_params = {}
-
- collective.barrier(group_name=sync_group_name)
- update_start_time = time.time()
-
- rollout_name = self.config.rollout.name
- inference_model = None
- if self._is_rollout and (not self._is_actor):
- if rollout_name == "vllm":
- inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
- from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
- patch_vllm_moe_model_weight_loader(inference_model)
- elif rollout_name == "sglang":
- if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- raise NotImplementedError(
- "Fully async sglang backend does not support "
- f"actor strategy: {self.config.actor_rollout_ref.actor.strategy}"
- )
+class DetachActorWorker(ActorRolloutRefWorker):
+ """
+ A worker class that extends ActorRolloutRefWorker to support detaching and restoring the actor model.
- inference_model = self.rollout._engine
- # For ServerAdapter, _engine might be None and needs async initialization
- if inference_model is None:
- # Initialize the server adapter engine
- print("[sync_rollout_weights] Initialize server adapter engine")
+ This worker facilitates saving the model state to CPU and restoring it, enabling efficient
+ resource management and checkpointing in distributed training. It currently supports
+ FSDP, FSDP2, and Megatron strategies.
+ """
- async def init_engine():
- if hasattr(self.rollout, "_init_server_adapter"):
- await self.rollout._init_server_adapter()
- else:
- print("[sync_rollout_weights] No _init_server_adapter method found")
- return self.rollout._engine
-
- inference_model = self._run_async_safely(init_engine())
- if inference_model is None:
- raise RuntimeError(
- f"Failed to initialize rollout engine. "
- f"rollout type: {type(self.rollout)}, "
- f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}"
- )
- else:
- raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
-
- # Update the checkpoint with the inference model and broadcast weights
- self.checkpoint_engine.update_checkpoint(
- inference_model=inference_model,
- group_name=sync_group_name,
- overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume,
- )
-
- update_end_time = time.time()
- update_duration = update_end_time - update_start_time
-
- if self.config.actor_rollout_ref.actor.strategy == "megatron":
- collective.barrier(group_name=sync_group_name)
-
- offload_start_time = time.time()
- if self._is_actor and self.engine._is_offload_param:
- self.offload_model_to_cpu()
- offload_duration = time.time() - offload_start_time
-
- print(
- f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()},"
- f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout},"
- f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, "
- f" register cost {register_duration} seconds, update cost {update_duration} seconds"
- )
-
- if self._is_actor and self.engine._is_offload_param:
- print(
- f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
- f" offload model to cpu cost {offload_duration} seconds"
- )
-
-
-class DetachActorWorker(DetachNcclSync):
def __init__(self, config: DictConfig, role: str):
- print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...")
- DetachNcclSync.__init__(self, config, role)
+ """
+ Initialize the DetachActorWorker.
+
+ Args:
+ config: Configuration dictionary.
+ role: The role of the worker (e.g., 'actor', 'rollout', 'ref').
+ """
+ ActorRolloutRefWorker.__init__(self, config, role)
self._strategy_handlers = None
self.copy_handler, self.restore_handler = self._get_strategy_handlers()
- def _get_actor_params(self):
- # TODO: Refator this function for the chekpoint engine
- assert self._is_actor
- params = self.actor_module_fsdp.state_dict()
- from verl.utils.model import convert_weight_keys
-
- params = convert_weight_keys(
- params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
- )
- return params
-
- def _get_actor_params_generator(self):
- # TODO: Refator this function for the chekpoint engine
- assert self._is_actor
- if self.bridge is not None:
- generator = self.bridge.export_weights(self.actor.actor_module)
- else:
- generator = per_tensor_generator(
- self.actor.actor_module,
- self.actor_model_config,
- self.weight_converter,
- self.tf_config,
- self.layer_name_mapping,
- )
+ def _get_strategy_handlers(self):
+ """
+ Get the strategy-specific handlers for saving and restoring the model.
- return generator
+ Returns:
+ tuple: A tuple containing (save_handler, restore_handler).
- def _get_strategy_handlers(self):
+ Raises:
+ NotImplementedError: If the strategy is not supported.
+ """
if self._strategy_handlers is not None:
return self._strategy_handlers
- strategy = self.config.actor_rollout_ref.actor.strategy
+ strategy = self.config.actor.strategy
if strategy in ["fsdp", "fsdp2"]:
from verl.experimental.fully_async_policy.fsdp2_utils import (
@@ -319,47 +87,14 @@ def _get_strategy_handlers(self):
return self._strategy_handlers
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def get_actor_weights_info(self):
- # TODO: Refator this function for the chekpoint engine
- assert self._is_actor
- if hasattr(self, "_weights_info"):
- return self._weights_info
-
- ret = []
-
- if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
- if fsdp_version(self.actor_module_fsdp) == 1:
- from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
-
- FSDP.set_state_dict_type(
- self.actor_module_fsdp,
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
- state_dict_config=ShardedStateDictConfig(),
- )
- params = self._get_actor_params()
-
- for key, tensor in params.items():
- ret.append((key, tensor.size(), tensor.dtype))
-
- elif self.config.actor_rollout_ref.actor.strategy == "megatron":
- if self.engine._is_offload_param:
- self.load_model_to_gpu()
-
- params_generator = self._get_actor_params_generator()
-
- for key, tensor in params_generator:
- ret.append((key, tensor.size(), tensor.dtype))
-
- else:
- raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}")
-
- self._weights_info = ret
-
- return ret
-
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_model_to_cpu(self, n):
+ """
+ Save the current model state to CPU memory.
+
+ Args:
+ n: Identifier/Key for the saved model state.
+ """
if not hasattr(self, "cpu_saved_models"):
self.cpu_saved_models = {}
@@ -367,8 +102,14 @@ def save_model_to_cpu(self, n):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def restore_model_from_cpu(self, n):
+ """
+ Restore the model state from CPU memory.
+
+ Args:
+ n: Identifier/Key for the saved model state to restore.
+ """
if n in self.cpu_saved_models:
- strategy = self.config.actor_rollout_ref.actor.strategy
+ strategy = self.config.actor.strategy
if strategy in ["fsdp", "fsdp2"]:
cpu_sharded_state, global_spec = self.cpu_saved_models[n]
@@ -378,16 +119,11 @@ def restore_model_from_cpu(self, n):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_cpu_model(self, n):
+ """
+ Clear the saved model state from CPU memory.
+
+ Args:
+ n: Identifier/Key for the saved model state to remove.
+ """
if n in self.cpu_saved_models:
del self.cpu_saved_models[n]
-
-
-class DetachAsyncRolloutWorker(DetachNcclSync):
- def __init__(self, config: DictConfig, role: str):
- print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
- DetachNcclSync.__init__(self, config, role)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def set_actor_weights_info(self, weights_info):
- assert self._is_rollout
- self._weights_info = weights_info
diff --git a/verl/experimental/separation/ray_trainer.py b/verl/experimental/separation/ray_trainer.py
index 56945445f6e..ca850b0590d 100644
--- a/verl/experimental/separation/ray_trainer.py
+++ b/verl/experimental/separation/ray_trainer.py
@@ -103,6 +103,7 @@ def __init__(
# reward message
self.reward_tensor = None
self.reward_extra_infos_dict = {}
+ self.checkpoint_manager = None
def init_workers(self):
"""Initialize distributed training workers using Ray backend.
@@ -119,7 +120,7 @@ def init_workers(self):
self._init_async_rollout_manager()
self.checkpoint_manager = CheckpointEngineManager(
- backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
+ config=omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine),
trainer=self.actor_rollout_wg,
replicas=self.async_rollout_manager.rollout_replicas,
)
@@ -404,15 +405,12 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
gen_batch_output = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
with marked_timer("gen", timing_raw, color="red"):
- if not self.async_rollout_mode:
- gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
- else:
- if self.curr_step_profile:
- self.async_rollout_manager.start_profile(global_step=self.global_steps)
- gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
- self.checkpoint_manager.sleep_replicas()
- if self.curr_step_profile:
- self.async_rollout_manager.stop_profile()
+ if self.curr_step_profile:
+ self.async_rollout_manager.start_profile(global_step=self.global_steps)
+ gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
+ self.checkpoint_manager.sleep_replicas()
+ if self.curr_step_profile:
+ self.async_rollout_manager.stop_profile()
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
@@ -421,15 +419,12 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
- if not self.async_rollout_mode:
- gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
- else:
- if self.curr_step_profile:
- self.async_rollout_manager.start_profile()
- gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
- self.checkpoint_manager.sleep_replicas()
- if self.curr_step_profile:
- self.async_rollout_manager.stop_profile()
+ if self.curr_step_profile:
+ self.async_rollout_manager.start_profile()
+ gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+ self.checkpoint_manager.sleep_replicas()
+ if self.curr_step_profile:
+ self.async_rollout_manager.stop_profile()
batch = batch.union(gen_baseline_output)
# compute reward model score on batch
rm_scores = None
diff --git a/verl/experimental/transfer_queue/agent_loop.py b/verl/experimental/transfer_queue/agent_loop.py
deleted file mode 100644
index 308be43ebc0..00000000000
--- a/verl/experimental/transfer_queue/agent_loop.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import asyncio
-
-import numpy as np
-import ray
-from transfer_queue import BatchMeta
-
-import verl.experimental.agent_loop.agent_loop as agent_loop
-
-
-class AgentLoopManager(agent_loop.AgentLoopManager):
- def generate_sequences(self, prompts: BatchMeta) -> BatchMeta:
- """Split input batch and dispatch to agent loop workers.
-
- Args:
- prompts (BatchMeta): Input batch.
-
- Returns:
- BatchMeta: Output batch metadata.
- """
-
- chunkes = prompts.chunk(len(self.agent_loop_workers))
- outputs = ray.get(
- [
- worker.generate_sequences.remote(chunk)
- for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
- ]
- )
- output = BatchMeta.concat(outputs)
-
- # calculate performance metrics
- metrics = [output.extra_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]]
- timing = self._performance_metrics(metrics, output)
-
- output.set_extra_info("timing", timing)
- return output
-
- def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: BatchMeta) -> dict[str, float]:
- timing = {}
- t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk])
- t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk])
- timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min()
- timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max()
- timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean()
- timing["agent_loop/tool_calls/min"] = t_tool_calls.min()
- timing["agent_loop/tool_calls/max"] = t_tool_calls.max()
- timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean()
-
- # TODO (TQ): initialize tq during init when enable TQ switch is stable
- tq_client = self._create_transferqueue_client()
- # batch sequence generation is bounded by the slowest sample
- slowest = np.argmax(t_generate_sequences + t_tool_calls)
- attention_mask = asyncio.run(tq_client.async_get_data(output[slowest]))["attention_mask"]
- prompt_length = output.samples[0].fields["prompts"].shape[0]
- timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
- timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
- timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
- timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()
-
- return timing
-
- def create_transferqueue_client_for_workers(self):
- # TODO (TQ): initialize tq during worker init when enable TQ switch is stable
- ray.get([worker.create_transferqueue_client.remote() for worker in self.agent_loop_workers])
-
- def _create_transferqueue_client(self):
- """Create a client for data system (TransferQueue)."""
- from verl.single_controller.ray.base import get_random_string
- from verl.utils.transferqueue_utils import create_transferqueue_client
-
- client_name = get_random_string(length=6)
-
- tq_client = create_transferqueue_client(
- client_id=f"AgentLoopManager_{client_name}",
- config=self.config.transfer_queue,
- )
-
- return tq_client
diff --git a/verl/experimental/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml b/verl/experimental/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml
deleted file mode 100644
index 37b19b45708..00000000000
--- a/verl/experimental/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-hydra:
- searchpath:
- - file://verl/trainer/config
-
-defaults:
- - ppo_megatron_trainer
- - _self_
-
-# config for TransferQueue
-transfer_queue:
- enable: True
- num_global_batch: 1
- storage_backend: AsyncSimpleStorageManager
- num_data_storage_units: 8
\ No newline at end of file
diff --git a/verl/experimental/transfer_queue/config/transfer_queue_ppo_trainer.yaml b/verl/experimental/transfer_queue/config/transfer_queue_ppo_trainer.yaml
deleted file mode 100644
index 7a5f57ddd4f..00000000000
--- a/verl/experimental/transfer_queue/config/transfer_queue_ppo_trainer.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-hydra:
- searchpath:
- - file://verl/trainer/config
-
-defaults:
- - ppo_trainer
- - _self_
-
-# config for TransferQueue
-transfer_queue:
- enable: True
- num_global_batch: 1
- storage_backend: AsyncSimpleStorageManager
- num_data_storage_units: 8
diff --git a/verl/experimental/transfer_queue/main_ppo.py b/verl/experimental/transfer_queue/main_ppo.py
deleted file mode 100644
index 075a8f7ace6..00000000000
--- a/verl/experimental/transfer_queue/main_ppo.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
-"""
-
-import os
-import socket
-
-import hydra
-import ray
-from omegaconf import OmegaConf
-
-from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
-from verl.trainer.main_ppo import TaskRunner as MainTaskRunner
-from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
-from verl.trainer.ppo.utils import need_critic, need_reference_policy
-from verl.utils.config import validate_config
-from verl.utils.device import auto_set_device, is_cuda_available
-
-from .ray_trainer import RayPPOTrainer
-
-
-@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
-def main(config):
- """Main entry point for PPO training with Hydra configuration management.
-
- Args:
- config_dict: Hydra configuration dictionary containing training parameters.
- """
- # Automatically set `config.trainer.device = npu` when running on Ascend NPU.
- auto_set_device(config)
-
- run_ppo(config)
-
-
-# Define a function to run the PPO-like training process
-def run_ppo(config, task_runner_class=None) -> None:
- """Initialize Ray cluster and run distributed PPO training process.
-
- Args:
- config: Training configuration object containing all necessary parameters
- for distributed PPO training including Ray initialization settings,
- model paths, and training hyperparameters.
- task_runner_class: For recipe to change TaskRunner.
- """
- # Check if Ray is not initialized
- if not ray.is_initialized():
- # Initialize Ray with a local cluster configuration
- # Set environment variables in the runtime environment to control tokenizer parallelism,
- # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
- # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
- default_runtime_env = get_ppo_ray_runtime_env()
- ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
- runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
-
- if config.transfer_queue.enable:
- # Add runtime environment variables for transfer queue
- runtime_env_vars = runtime_env_kwargs.get("env_vars", {})
- runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1"
- runtime_env_kwargs["env_vars"] = runtime_env_vars
-
- runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
- ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
- print(f"ray init kwargs: {ray_init_kwargs}")
- ray.init(**OmegaConf.to_container(ray_init_kwargs))
-
- if task_runner_class is None:
- task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head
-
- # Create a remote instance of the TaskRunner class, and
- # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
- if (
- is_cuda_available
- and config.global_profiler.tool == "nsys"
- and config.global_profiler.get("steps") is not None
- and len(config.global_profiler.get("steps", [])) > 0
- ):
- from verl.utils.import_utils import is_nvtx_available
-
- assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
- nsight_options = OmegaConf.to_container(
- config.global_profiler.global_tool_config.nsys.controller_nsight_options
- )
- runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
- else:
- runner = task_runner_class.remote()
- ray.get(runner.run.remote(config))
-
- # [Optional] get the path of the timeline trace file from the configuration, default to None
- # This file is used for performance analysis
- timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
- if timeline_json_file:
- ray.timeline(filename=timeline_json_file)
-
-
-class TaskRunner(MainTaskRunner):
- def run(self, config):
- """Execute the main PPO training workflow.
-
- This method sets up the distributed training environment, initializes
- workers, datasets, and reward functions, then starts the training process.
-
- Args:
- config: Training configuration object containing all parameters needed
- for setting up and running the PPO training process.
- """
- # Print the initial configuration. `resolve=True` will evaluate symbolic values.
- from pprint import pprint
-
- from verl.utils.fs import copy_to_local
-
- print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
- pprint(OmegaConf.to_container(config, resolve=True))
- OmegaConf.resolve(config)
-
- actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
- self.add_critic_worker(config)
-
- # We should adopt a multi-source reward function here:
- # - for rule-based rm, we directly call a reward score
- # - for model-based rm, we call a model
- # - for code related prompt, we send to a sandbox if there are test cases
- # finally, we combine all the rewards together
- # The reward type depends on the tag of the data
- self.add_reward_model_resource_pool(config)
-
- # Add a reference policy worker if KL loss or KL reward is used.
- self.add_ref_policy_worker(config, actor_rollout_cls)
-
- # validate config
- validate_config(
- config=config,
- use_reference_policy=need_reference_policy(config),
- use_critic=need_critic(config),
- )
-
- # Download the checkpoint from HDFS to the local machine.
- # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
- local_path = copy_to_local(
- config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
- )
-
- # Instantiate the tokenizer and processor.
- from verl.utils import hf_processor, hf_tokenizer
-
- trust_remote_code = config.data.get("trust_remote_code", False)
- tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
- # Used for multimodal LLM, could be None
- processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
-
- resource_pool_manager = self.init_resource_pool_mgr(config)
-
- from verl.utils.dataset.rl_dataset import collate_fn
-
- # Create training and validation datasets.
- train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)
- val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)
- train_sampler = create_rl_sampler(config.data, train_dataset)
-
- # Initialize the PPO trainer.
- trainer = RayPPOTrainer(
- config=config,
- tokenizer=tokenizer,
- processor=processor,
- role_worker_mapping=self.role_worker_mapping,
- resource_pool_manager=resource_pool_manager,
- ray_worker_group_cls=ray_worker_group_cls,
- train_dataset=train_dataset,
- val_dataset=val_dataset,
- collate_fn=collate_fn,
- train_sampler=train_sampler,
- )
- # Initialize the workers of the trainer.
- trainer.init_workers()
- # Start the training process.
- trainer.fit()
-
-
-if __name__ == "__main__":
- main()
diff --git a/verl/experimental/transfer_queue/ray_trainer.py b/verl/experimental/transfer_queue/ray_trainer.py
deleted file mode 100644
index 3d7c0f58d30..00000000000
--- a/verl/experimental/transfer_queue/ray_trainer.py
+++ /dev/null
@@ -1,1614 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-# Copyright 2023-2024 SGLang Team
-# Copyright 2025 ModelBest Inc. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-PPO Trainer with Ray-based single controller.
-This trainer supports model-agonistic model initialization with huggingface
-"""
-
-import json
-import logging
-import math
-import os
-import uuid
-from collections import defaultdict
-from pprint import pprint
-from typing import Any, Optional
-
-import numpy as np
-import tensordict
-import torch
-from omegaconf import OmegaConf, open_dict
-from packaging.version import parse as parse_version
-from tensordict import TensorDict
-from torch.utils.data import Dataset, Sampler
-from torchdata.stateful_dataloader import StatefulDataLoader
-from tqdm import tqdm
-from transfer_queue import (
- BatchMeta,
- SimpleStorageUnit,
- TransferQueueController,
- get_placement_group,
- process_zmq_server_info,
-)
-
-from verl import DataProto
-from verl.checkpoint_engine import CheckpointEngineManager
-from verl.experimental.dataset.sampler import AbstractCurriculumSampler
-from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
-from verl.single_controller.ray.base import create_colocated_worker_cls
-from verl.trainer.config import AlgoConfig
-from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
-from verl.trainer.ppo.metric_utils import (
- compute_data_metrics,
- compute_throughout_metrics,
- compute_timing_metrics,
- process_validation_metrics,
-)
-from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
-from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
-from verl.utils.config import omega_conf_to_dataclass
-from verl.utils.debug import marked_timer
-from verl.utils.metric import reduce_metrics
-from verl.utils.rollout_skip import RolloutSkip
-from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
-from verl.utils.torch_functional import masked_mean
-from verl.utils.tracking import ValidationGenerationsLogger
-from verl.utils.transferqueue_utils import create_transferqueue_client, get_transferqueue_client, tqbridge
-
-
-@tqbridge(put_data=False)
-def compute_reward_decorated(data):
- reward_tensor = data.batch["rm_scores"]
- reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
- reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
- return reward_tensor, reward_extra_info
-
-
-@tqbridge(put_data=False)
-def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
- """Apply KL penalty to the token-level rewards.
-
- This function computes the KL divergence between the reference policy and current policy,
- then applies a penalty to the token-level rewards based on this divergence.
-
- Args:
- data (DataProto): The data containing batched model outputs and inputs.
- kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.
- kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl".
-
- Returns:
- tuple: A tuple containing:
- - The updated data with token-level rewards adjusted by KL penalty
- - A dictionary of metrics related to the KL penalty
- """
- response_mask = data.batch["response_mask"]
- token_level_scores = data.batch["token_level_scores"]
- batch_size = data.batch.batch_size[0]
-
- # compute kl between ref_policy and current policy
- # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
- kld = core_algos.kl_penalty(
- data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
- ) # (batch_size, response_length)
- kld = kld * response_mask
- beta = kl_ctrl.value
-
- token_level_rewards = token_level_scores - beta * kld
-
- current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
- current_kl = torch.mean(current_kl, dim=0).item()
-
- # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
- kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
-
- metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}
-
- return token_level_rewards, metrics
-
-
-def compute_response_mask(batch_meta: BatchMeta, tq_client):
- """Compute the attention mask for the response part of the sequence.
-
- This function extracts the portion of the attention mask that corresponds to the model's response,
- which is used for masking computations that should only apply to response tokens.
-
- Args:
- batch_meta (BatchMeta): The data containing batched model outputs and inputs.
-
- Returns:
- BatchMeta: The BatchMeta of attention mask for the response tokens.
- """
- data = tq_client.get_data(batch_meta)
-
- responses = data["responses"]
- response_length = responses.size(1)
- attention_mask = data["attention_mask"]
- response_mask = attention_mask[:, -response_length:]
- output = TensorDict({"response_mask": response_mask}, batch_size=response_mask.size(0))
-
- batch_meta = tq_client.put(data=output, metadata=batch_meta)
-
- return batch_meta
-
-
-@tqbridge(put_data=False)
-def compute_advantage(
- data: DataProto,
- adv_estimator: AdvantageEstimator,
- gamma: float = 1.0,
- lam: float = 1.0,
- num_repeat: int = 1,
- norm_adv_by_std_in_grpo: bool = True,
- config: Optional[AlgoConfig] = None,
-) -> tuple[Any, Any]:
- """Compute advantage estimates for policy optimization.
-
- This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.
- The advantage estimates are used to guide policy optimization in RL algorithms.
-
- Args:
- data (DataProto): The data containing batched model outputs and inputs.
- adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).
- gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.
- lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.
- num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1.
- norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in
- GRPO. Defaults to True.
- config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None.
-
- Returns:
- tuple: A tuple containing:
- - advantages: The computed advantage estimates.
- - returns: The computed returns.
- """
- # prepare response group
- if adv_estimator == AdvantageEstimator.GAE:
- # Compute advantages and returns using Generalized Advantage Estimation (GAE)
- advantages, returns = core_algos.compute_gae_advantage_return(
- token_level_rewards=data.batch["token_level_rewards"],
- values=data.batch["values"],
- response_mask=data.batch["response_mask"],
- gamma=gamma,
- lam=lam,
- )
- # TODO (TQ): adapt core_algos.compute_pf_ppo_reweight_data function to support transfer queue
- if config.get("use_pf_ppo", False):
- data = core_algos.compute_pf_ppo_reweight_data(
- data,
- config.pf_ppo.get("reweight_method"),
- config.pf_ppo.get("weight_pow"),
- )
- elif adv_estimator == AdvantageEstimator.GRPO:
- # Initialize the mask for GRPO calculation
- grpo_calculation_mask = data.batch["response_mask"]
- # Call compute_grpo_outcome_advantage with parameters matching its definition
- advantages, returns = core_algos.compute_grpo_outcome_advantage(
- token_level_rewards=data.batch["token_level_rewards"],
- response_mask=grpo_calculation_mask,
- index=data.non_tensor_batch["uid"],
- norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
- )
- else:
- # handle all other adv estimator type other than GAE and GRPO
- adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
- adv_kwargs = {
- "token_level_rewards": data.batch["token_level_rewards"],
- "response_mask": data.batch["response_mask"],
- "config": config,
- }
- if "uid" in data.non_tensor_batch: # optional
- adv_kwargs["index"] = data.non_tensor_batch["uid"]
- if "reward_baselines" in data.batch: # optional
- adv_kwargs["reward_baselines"] = data.batch["reward_baselines"]
-
- # calculate advantage estimator
- advantages, returns = adv_estimator_fn(**adv_kwargs)
- return advantages, returns
-
-
-@tqbridge(put_data=False)
-def compute_data_metrics_decorated(batch, use_critic: bool = True):
- return compute_data_metrics(batch, use_critic)
-
-
-@tqbridge(put_data=False)
-def compute_timing_metrics_decorated(batch, timing_raw: dict[str, float]) -> dict[str, Any]:
- return compute_timing_metrics(batch, timing_raw)
-
-
-@tqbridge(put_data=False)
-def compute_throughout_metrics_decorated(batch, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:
- return compute_throughout_metrics(batch, timing_raw, n_gpus)
-
-
-@tqbridge(put_data=False)
-def calculate_debug_metrics_decorated(data):
- from verl.utils.debug.metrics import calculate_debug_metrics
-
- return calculate_debug_metrics(data)
-
-
-class RayPPOTrainer:
- """Distributed PPO trainer using Ray for scalable reinforcement learning.
-
- This trainer orchestrates distributed PPO training across multiple nodes and GPUs,
- managing actor rollouts, critic training, and reward computation with Ray backend.
- Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration.
- """
-
- # TODO: support each role have individual ray_worker_group_cls,
- # i.e., support different backend of different role
- def __init__(
- self,
- config,
- tokenizer,
- role_worker_mapping: dict[Role, WorkerType],
- resource_pool_manager: ResourcePoolManager,
- ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,
- processor=None,
- train_dataset: Optional[Dataset] = None,
- val_dataset: Optional[Dataset] = None,
- collate_fn=None,
- train_sampler: Optional[Sampler] = None,
- device_name=None,
- ):
- """
- Initialize distributed PPO trainer with Ray backend.
- Note that this trainer runs on the driver process on a single CPU/GPU node.
-
- Args:
- config: Configuration object containing training parameters.
- tokenizer: Tokenizer used for encoding and decoding text.
- role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.
- resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.
- ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.
- processor: Optional data processor, used for multimodal data
- train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.
- val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
- collate_fn: Function to collate data samples into batches.
- train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
- device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None.
- """
-
- # Store the tokenizer for text processing
- self.tokenizer = tokenizer
- self.processor = processor
- self.config = config
-
- self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
- assert self.hybrid_engine, "Currently, only support hybrid engine"
-
- if self.hybrid_engine:
- assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
-
- self.role_worker_mapping = role_worker_mapping
- self.resource_pool_manager = resource_pool_manager
- self.use_reference_policy = need_reference_policy(self.config)
- self.use_rm = need_reward_model(self.config)
- self.use_critic = need_critic(self.config)
- self.ray_worker_group_cls = ray_worker_group_cls
- self.device_name = device_name if device_name else self.config.trainer.device
- self.validation_generations_logger = ValidationGenerationsLogger(
- project_name=self.config.trainer.project_name,
- experiment_name=self.config.trainer.experiment_name,
- )
-
- lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
- if lora_rank <= 0:
- lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
- # if ref_in_actor is True, the reference policy will be actor without lora applied
- self.ref_in_actor = lora_rank > 0
-
- # define in-reward KL control
- # kl loss control currently not suppoorted
- if self.config.algorithm.use_kl_in_reward:
- self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
-
- self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
-
- self.tq_client = self._initialize_transferqueue()
-
- def _initialize_transferqueue(self):
- # 1. initialize TransferQueueStorage
- if self.config.transfer_queue.storage_backend == "AsyncSimpleStorageManager":
- train_data_size = (
- self.config.data.train_batch_size
- * self.config.transfer_queue.num_global_batch
- * self.config.actor_rollout_ref.rollout.n
- )
- val_data_size = self.val_dataset_size * self.config.actor_rollout_ref.rollout.val_kwargs.n
-
- total_storage_size = train_data_size + val_data_size
- self.data_system_storage_units = {}
- storage_placement_group = get_placement_group(
- self.config.transfer_queue.num_data_storage_units, num_cpus_per_actor=1
- )
- for storage_unit_rank in range(self.config.transfer_queue.num_data_storage_units):
- storage_node = SimpleStorageUnit.options(
- placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
- ).remote(
- storage_unit_size=math.ceil(total_storage_size / self.config.transfer_queue.num_data_storage_units)
- )
- self.data_system_storage_units[storage_unit_rank] = storage_node
- logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
- else:
- raise NotImplementedError("Currently only support AsyncSimpleStorageManager backend in TransferQueue")
-
- # 2. Initialize TransferQueueController (single controller only)
-
- # Sampler usage instructions:
- # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler:
- # Option 1: Pass sampler class (will be instantiated automatically)
- # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
-
- # Option 2: Pass sampler instance (if you need custom configuration)
- # grpo_sampler = GRPOGroupNSampler()
- # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
-
- # Then use sampling_config in get_meta calls:
- # sampling_config={"n_samples_per_prompt": 4}
- self.data_system_controller = TransferQueueController.remote()
- logging.info("TransferQueueController has been created.")
-
- # 3. register controller & storage and prepare necessary information
- self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
- if self.config.transfer_queue.storage_backend == "AsyncSimpleStorageManager":
- self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
-
- # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances
- # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts,
- # breaking the transfer queue client initialization.
- tq_config = OmegaConf.create({"transfer_queue": {}}, flags={"allow_objects": True})
- tq_config.transfer_queue.controller_info = self.data_system_controller_info
-
- if self.config.transfer_queue.storage_backend == "AsyncSimpleStorageManager":
- tq_config.transfer_queue.storage_unit_infos = self.data_system_storage_unit_infos
-
- self.config = OmegaConf.merge(tq_config, self.config)
-
- # 4. create client
- create_transferqueue_client(client_id="Trainer", config=self.config.transfer_queue, sync=True)
- tq_client = get_transferqueue_client()
- return tq_client
-
- def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
- """
- Creates the train and validation dataloaders.
- """
- # TODO: we have to make sure the batch size is divisible by the dp size
- from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
-
- if train_dataset is None:
- train_dataset = create_rl_dataset(
- self.config.data.train_files, self.config.data, self.tokenizer, self.processor
- )
- if val_dataset is None:
- val_dataset = create_rl_dataset(
- self.config.data.val_files, self.config.data, self.tokenizer, self.processor
- )
- self.train_dataset, self.val_dataset = train_dataset, val_dataset
-
- self.val_dataset_size = len(val_dataset)
-
- if train_sampler is None:
- train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
- if collate_fn is None:
- from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
-
- collate_fn = default_collate_fn
-
- num_workers = self.config.data["dataloader_num_workers"]
-
- self.train_dataloader = StatefulDataLoader(
- dataset=self.train_dataset,
- batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
- num_workers=num_workers,
- drop_last=True,
- collate_fn=collate_fn,
- sampler=train_sampler,
- )
-
- val_batch_size = self.config.data.val_batch_size # Prefer config value if set
- if val_batch_size is None:
- val_batch_size = len(self.val_dataset)
- self.val_batch_size = val_batch_size
-
- self.val_dataloader = StatefulDataLoader(
- dataset=self.val_dataset,
- batch_size=val_batch_size,
- num_workers=num_workers,
- shuffle=self.config.data.get("validation_shuffle", True),
- drop_last=False,
- collate_fn=collate_fn,
- )
-
- assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
- assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
-
- print(
- f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: "
- f"{len(self.val_dataloader)}"
- )
-
- total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
-
- if self.config.trainer.total_training_steps is not None:
- total_training_steps = self.config.trainer.total_training_steps
-
- self.total_training_steps = total_training_steps
- print(f"Total training steps: {self.total_training_steps}")
-
- try:
- OmegaConf.set_struct(self.config, True)
- with open_dict(self.config):
- if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
- self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
- if OmegaConf.select(self.config, "critic.optim"):
- self.config.critic.optim.total_training_steps = total_training_steps
- except Exception as e:
- print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
-
- def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
- """Dump rollout/validation samples as JSONL."""
- os.makedirs(dump_path, exist_ok=True)
- filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
-
- n = len(inputs)
- base_data = {
- "input": inputs,
- "output": outputs,
- "gts": gts,
- "score": scores,
- "step": [self.global_steps] * n,
- }
-
- for k, v in reward_extra_infos_dict.items():
- if len(v) == n:
- base_data[k] = v
-
- lines = []
- for i in range(n):
- entry = {k: v[i] for k, v in base_data.items()}
- lines.append(json.dumps(entry, ensure_ascii=False))
-
- with open(filename, "w") as f:
- f.write("\n".join(lines) + "\n")
-
- print(f"Dumped generations to {filename}")
-
- def _log_rollout_data(
- self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str
- ):
- """
- Log rollout data to disk.
-
- Args:
- log_rollout_meta (BatchMeta): The batch_meta of rollout data
- reward_extra_infos_dict (dict): Additional reward information to log
- timing_raw (dict): Timing information for profiling
- rollout_data_dir (str): Directory path to save the rollout data
- """
- with marked_timer("dump_rollout_generations", timing_raw, color="green"):
- data = self.tq_client.get_data(log_rollout_meta)
-
- inputs = self.tokenizer.batch_decode(data["prompts"], skip_special_tokens=True)
- outputs = self.tokenizer.batch_decode(data["responses"], skip_special_tokens=True)
- scores = data["token_level_scores"].sum(-1).cpu().tolist()
- sample_gts = [item.get("ground_truth", None) for item in data.get("reward_model", {})]
-
- reward_extra_infos_to_dump = reward_extra_infos_dict.copy()
- if "request_id" in log_rollout_meta.field_names:
- reward_extra_infos_dict.setdefault(
- "request_id",
- data["request_id"].tolist(),
- )
-
- self._dump_generations(
- inputs=inputs,
- outputs=outputs,
- gts=sample_gts,
- scores=scores,
- reward_extra_infos_dict=reward_extra_infos_to_dump,
- dump_path=rollout_data_dir,
- )
-
- def _maybe_log_val_generations(self, inputs, outputs, scores):
- """Log a table of validation samples to the configured logger (wandb or swanlab)"""
-
- generations_to_log = self.config.trainer.log_val_generations
-
- if generations_to_log == 0:
- return
-
- import numpy as np
-
- # Create tuples of (input, output, score) and sort by input text
- samples = list(zip(inputs, outputs, scores, strict=True))
- samples.sort(key=lambda x: x[0]) # Sort by input text
-
- # Use fixed random seed for deterministic shuffling
- rng = np.random.RandomState(42)
- rng.shuffle(samples)
-
- # Take first N samples after shuffling
- samples = samples[:generations_to_log]
-
- # Log to each configured logger
- self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
-
- def _get_gen_batch(self, batch: DataProto) -> DataProto:
- reward_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys()
-
- # pop those keys for generation
- batch_keys_to_pop = []
- non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_keys
- gen_batch = batch.pop(
- batch_keys=batch_keys_to_pop,
- non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),
- )
-
- # For agent loop, we need reward model keys to compute score.
- if self.async_rollout_mode:
- gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
-
- return gen_batch
-
- def _validate(self):
- data_source_lst = []
- reward_extra_infos_dict: dict[str, list] = defaultdict(list)
-
- # Lists to collect samples for the table
- sample_inputs = []
- sample_outputs = []
- sample_gts = []
- sample_scores = []
- sample_turns = []
- sample_uids = []
-
- for test_data in self.val_dataloader:
- if "uid" not in test_data.keys():
- test_data["uid"] = np.array(
- [str(uuid.uuid4()) for _ in range(len(test_data["raw_prompt"]))], dtype=object
- )
-
- # repeat test data
- repeated_test_data = self.repeat_dict(
- test_data, repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
- )
-
- test_batch: TensorDict = self.dict_to_tensordict(repeated_test_data)
-
- # we only do validation on rule-based rm
- if self.config.reward.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
- return {}
-
- batch_meta = self.tq_client.put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")
-
- batch_meta.update_extra_info(
- {
- "eos_token_id": self.tokenizer.eos_token_id,
- "pad_token_id": self.tokenizer.pad_token_id,
- "recompute_log_prob": False,
- "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
- "validate": True,
- "global_steps": self.global_steps,
- }
- )
- print(f"batch_meta extra_info: {batch_meta.extra_info}")
-
- # TODO: (TQ) Support padding and unpadding to make DataProto divisible by dp_size with TransferQueue
- if not self.async_rollout_mode:
- test_output_gen_meta = self.actor_rollout_wg.generate_sequences(batch_meta)
- else:
- test_output_gen_meta = self.async_rollout_manager.generate_sequences(batch_meta)
-
- batch_meta = batch_meta.union(test_output_gen_meta)
-
- print("validation generation end")
-
- # Store generated outputs
- test_response_meta = batch_meta.select_fields(["prompts", "responses", "uid", "reward_model"])
- data = self.tq_client.get_data(test_response_meta)
- output_ids = data["responses"]
- output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
- sample_outputs.extend(output_texts)
-
- # TODO: Can we keep special tokens except for padding tokens?
- input_ids = data["prompts"]
- input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
- sample_inputs.extend(input_texts)
- sample_uids.extend(data["uid"])
-
- ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})]
- sample_gts.extend(ground_truths)
-
- reward_tensor, reward_extra_info = compute_reward_decorated(batch_meta)
-
- scores = reward_tensor.sum(-1).cpu().tolist()
- sample_scores.extend(scores)
-
- reward_extra_infos_dict["reward"].extend(scores)
- print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}")
- for key, lst in reward_extra_info.items():
- reward_extra_infos_dict[key].extend(lst)
- print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}")
-
- # collect num_turns of each prompt
- if "__num_turns__" in batch_meta.field_names:
- data = self.tq_client.get_data(batch_meta.select_fields(["__num_turns__"]))
- sample_turns.append(data["__num_turns__"])
-
- data_source = ["unknown"] * reward_tensor.shape[0]
- if "data_source" in batch_meta.field_names:
- data_source_meta = batch_meta.select_fields(["data_source"])
- data = self.tq_client.get_data(data_source_meta)
- data_source = data["data_source"]
-
- data_source_lst.append(data_source)
-
- self.tq_client.clear_samples(batch_meta)
-
- self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
-
- # dump generations
- val_data_dir = self.config.trainer.get("validation_data_dir", None)
- if val_data_dir:
- self._dump_generations(
- inputs=sample_inputs,
- outputs=sample_outputs,
- gts=sample_gts,
- scores=sample_scores,
- reward_extra_infos_dict=reward_extra_infos_dict,
- dump_path=val_data_dir,
- )
-
- for key_info, lst in reward_extra_infos_dict.items():
- assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
-
- data_sources = np.concatenate(data_source_lst, axis=0)
-
- data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
- metric_dict = {}
- for data_source, var2metric2val in data_src2var2metric2val.items():
- core_var = "acc" if "acc" in var2metric2val else "reward"
- for var_name, metric2val in var2metric2val.items():
- n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
- for metric_name, metric_val in metric2val.items():
- if (
- (var_name == core_var)
- and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
- and (f"@{n_max}" in metric_name)
- ):
- metric_sec = "val-core"
- else:
- metric_sec = "val-aux"
- pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
- metric_dict[pfx] = metric_val
-
- if len(sample_turns) > 0:
- sample_turns = np.concatenate(sample_turns)
- metric_dict["val-aux/num_turns/min"] = sample_turns.min()
- metric_dict["val-aux/num_turns/max"] = sample_turns.max()
- metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
-
- return metric_dict
-
- def init_workers(self):
- """Initialize distributed training workers using Ray backend.
-
- Creates:
- 1. Ray resource pools from configuration
- 2. Worker groups for each role (actor, critic, etc.)
- """
- self.resource_pool_manager.create_resource_pool()
-
- self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
-
- # create actor and rollout
- if self.hybrid_engine:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
- actor_rollout_cls = RayClassWithInitArgs(
- cls=self.role_worker_mapping[Role.ActorRollout],
- config=self.config.actor_rollout_ref,
- role="actor_rollout",
- )
- self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
- else:
- raise NotImplementedError
-
- # create critic
- if self.use_critic:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
- critic_cfg = omega_conf_to_dataclass(self.config.critic)
- critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
- self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
-
- # create reference policy if needed
- if self.use_reference_policy:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(
- self.role_worker_mapping[Role.RefPolicy],
- config=self.config.actor_rollout_ref,
- role="ref",
- )
- self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
-
- # initialize WorkerGroup
- # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
- # you should not use `create_colocated_worker_cls`.
- # Instead, directly pass different resource pool to different worker groups.
- # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
- all_wg = {}
- wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
- if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
- wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
- if OmegaConf.select(self.config.global_profiler, "steps") is not None:
- wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
- # Only require nsight worker options when tool is nsys
- if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
- assert (
- OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
- is not None
- ), "worker_nsight_options must be set when using nsys with profile_steps"
- wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
- OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
- )
- wg_kwargs["device_name"] = self.device_name
-
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(
- resource_pool=resource_pool,
- ray_cls_with_init=worker_dict_cls,
- **wg_kwargs,
- )
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
-
- if self.use_critic:
- self.critic_wg = all_wg["critic"]
- self.critic_wg.init_model()
-
- if self.use_reference_policy and not self.ref_in_actor:
- self.ref_policy_wg = all_wg["ref"]
- self.ref_policy_wg.init_model()
-
- # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
- self.actor_rollout_wg = all_wg["actor_rollout"]
- self.actor_rollout_wg.init_model()
-
- # set transferqueue server info for each worker
- for _, wg in all_wg.items():
- wg.create_transferqueue_client(self.config)
-
- # create reward loop manager
- from verl.experimental.reward_loop import RewardLoopManager
-
- # initalize reward loop manager
- # reward model (colocate or standalone): get resource_pool
- # no reward model: resource_pool = None
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None
- self.reward_loop_manager = RewardLoopManager(
- config=self.config,
- rm_resource_pool=resource_pool,
- )
-
- # create async rollout manager and request scheduler
- self.async_rollout_mode = False
- if self.config.actor_rollout_ref.rollout.mode == "async":
- from .agent_loop import AgentLoopManager
-
- self.async_rollout_mode = True
-
- enable_agent_reward_loop = not self.use_rm or self.config.reward.reward_model.enable_resource_pool
-
- reward_loop_worker_handles = (
- self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
- )
- self.async_rollout_manager = AgentLoopManager(
- config=self.config,
- worker_group=self.actor_rollout_wg,
- reward_loop_worker_handles=reward_loop_worker_handles,
- )
-
- self.checkpoint_manager = CheckpointEngineManager(
- backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
- trainer=self.actor_rollout_wg,
- replicas=self.async_rollout_manager.rollout_replicas,
- )
-
- # sleep all replicas to load checkpoint
- self.checkpoint_manager.sleep_replicas()
-
- # TODO (TQ): initialize tq during worker init when enable TQ switch is stable
- self.async_rollout_manager.create_transferqueue_client_for_workers()
-
- def _save_checkpoint(self):
- from verl.utils.fs import local_mkdir_safe
-
- # path: given_path + `/global_step_{global_steps}` + `/actor`
- local_global_step_folder = os.path.join(
- self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
- )
-
- print(f"local_global_step_folder: {local_global_step_folder}")
- actor_local_path = os.path.join(local_global_step_folder, "actor")
-
- actor_remote_path = (
- None
- if self.config.trainer.default_hdfs_dir is None
- else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
- )
-
- remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
- if remove_previous_ckpt_in_save:
- print(
- "Warning: remove_previous_ckpt_in_save is deprecated,"
- + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead"
- )
- max_actor_ckpt_to_keep = (
- self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
- )
- max_critic_ckpt_to_keep = (
- self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
- )
-
- self.actor_rollout_wg.save_checkpoint(
- actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep
- )
-
- if self.use_critic:
- critic_local_path = os.path.join(local_global_step_folder, "critic")
- critic_remote_path = (
- None
- if self.config.trainer.default_hdfs_dir is None
- else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
- )
- self.critic_wg.save_checkpoint(
- critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
- )
-
- # save dataloader
- local_mkdir_safe(local_global_step_folder)
- dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
- dataloader_state_dict = self.train_dataloader.state_dict()
- torch.save(dataloader_state_dict, dataloader_local_path)
-
- # latest checkpointed iteration tracker (for atomic usage)
- local_latest_checkpointed_iteration = os.path.join(
- self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
- )
- with open(local_latest_checkpointed_iteration, "w") as f:
- f.write(str(self.global_steps))
-
- def _load_checkpoint(self):
- if self.config.trainer.resume_mode == "disable":
- return 0
-
- # load from hdfs
- if self.config.trainer.default_hdfs_dir is not None:
- raise NotImplementedError("load from hdfs is not implemented yet")
- else:
- checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
- if not os.path.isabs(checkpoint_folder):
- working_dir = os.getcwd()
- checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
- global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
-
- # find global_step_folder
- if self.config.trainer.resume_mode == "auto":
- if global_step_folder is None:
- print("Training from scratch")
- return 0
- else:
- if self.config.trainer.resume_mode == "resume_path":
- assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
- assert "global_step_" in self.config.trainer.resume_from_path, (
- "resume ckpt must specify the global_steps"
- )
- global_step_folder = self.config.trainer.resume_from_path
- if not os.path.isabs(global_step_folder):
- working_dir = os.getcwd()
- global_step_folder = os.path.join(working_dir, global_step_folder)
- print(f"Load from checkpoint folder: {global_step_folder}")
- # set global step
- self.global_steps = int(global_step_folder.split("global_step_")[-1])
-
- print(f"Setting global step to {self.global_steps}")
- print(f"Resuming from {global_step_folder}")
-
- actor_path = os.path.join(global_step_folder, "actor")
- critic_path = os.path.join(global_step_folder, "critic")
- # load actor
- self.actor_rollout_wg.load_checkpoint(
- actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
- )
- # load critic
- if self.use_critic:
- self.critic_wg.load_checkpoint(
- critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
- )
-
- # load dataloader,
- # TODO: from remote not implemented yet
- dataloader_local_path = os.path.join(global_step_folder, "data.pt")
- if os.path.exists(dataloader_local_path):
- dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
- self.train_dataloader.load_state_dict(dataloader_state_dict)
- else:
- print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
-
- def _start_profiling(self, do_profile: bool) -> None:
- """Start profiling for all worker groups if profiling is enabled."""
- if do_profile:
- self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
- if self.use_reference_policy:
- self.ref_policy_wg.start_profile(profile_step=self.global_steps)
- if self.use_critic:
- self.critic_wg.start_profile(profile_step=self.global_steps)
-
- def _stop_profiling(self, do_profile: bool) -> None:
- """Stop profiling for all worker groups if profiling is enabled."""
- if do_profile:
- self.actor_rollout_wg.stop_profile()
- if self.use_reference_policy:
- self.ref_policy_wg.stop_profile()
- if self.use_critic:
- self.critic_wg.stop_profile()
-
- def _balance_batch(
- self, batch: BatchMeta, tq_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False
- ):
- """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens"""
- data = tq_client.get_data(batch)
-
- attention_mask = data["attention_mask"]
- batch_size = attention_mask.shape[0]
- global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,)
- global_seqlen_lst = calculate_workload(global_seqlen_lst)
- world_size = self.actor_rollout_wg.world_size
- if keep_minibatch:
- # Decouple the DP balancing and mini-batching.
- minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size", None)
- if minibatch_size is None:
- raise ValueError("'ppo_mini_batch_size' must be set in actor config when 'keep_minibatch' is True.")
- minibatch_num = len(global_seqlen_lst) // minibatch_size
- global_partition_lst = [[] for _ in range(world_size)]
- for i in range(minibatch_num):
- rearrange_minibatch_lst = get_seqlen_balanced_partitions(
- global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
- k_partitions=world_size,
- equal_size=True,
- )
- for j, part in enumerate(rearrange_minibatch_lst):
- global_partition_lst[j].extend([x + minibatch_size * i for x in part])
- else:
- global_partition_lst = get_seqlen_balanced_partitions(
- global_seqlen_lst, k_partitions=world_size, equal_size=True
- )
- # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
- for idx, partition in enumerate(global_partition_lst):
- partition.sort(key=lambda x: (global_seqlen_lst[x], x))
- ordered_partition = partition[::2] + partition[1::2][::-1]
- global_partition_lst[idx] = ordered_partition
- # reorder based on index. The data will be automatically equally partitioned by dispatch function
- global_idx = [j for partition in global_partition_lst for j in partition]
- global_balance_stats = log_seqlen_unbalance(
- seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
- )
- metrics.update(global_balance_stats)
- return global_idx
-
- @classmethod
- def repeat_dict(
- cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True
- ) -> dict[str, torch.Tensor | np.ndarray]:
- """
- Repeat the batch dict a specified number of times.
-
- Args:
- repeat_times (int): Number of times to repeat the data.
- interleave (bool): Whether to interleave the repeated data.
-
- Returns:
- dict: A new dict with repeated data.
- """
- if repeat_times == 1:
- return batch_dict
-
- repeated_batch_dict = {}
- if batch_dict:
- if interleave:
- # Interleave the data
- for key, val in batch_dict.items():
- if isinstance(val, torch.Tensor):
- repeated_batch_dict[key] = val.repeat_interleave(repeat_times, dim=0)
- elif isinstance(val, np.ndarray):
- repeated_batch_dict[key] = np.repeat(val, repeat_times, axis=0)
- else:
- raise ValueError(f"Unsupported type in data {type(val)}")
- else:
- # Stack the data
- for key, val in batch_dict.items():
- if isinstance(val, torch.Tensor):
- repeated_batch_dict[key] = (
- val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, *val.shape[1:])
- )
- elif isinstance(val, np.ndarray):
- repeated_batch_dict[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))
- else:
- raise ValueError(f"Unsupported type in data {type(val)}")
- return repeated_batch_dict
-
- @classmethod
- def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> TensorDict:
- """
- Create a TensorDict from a dict of tensors and non_tensors.
- Note that this requires tensordict version at least 0.10
- """
- assert parse_version(tensordict.__version__) >= parse_version("0.10"), (
- "Storing non-tensor data in TensorDict at least requires tensordict version 0.10"
- )
- tensors_batch = {}
- batch_size = None
-
- for key, val in data.items():
- if isinstance(val, torch.Tensor | np.ndarray):
- tensors_batch[key] = val
- else:
- raise ValueError(f"Unsupported type in data {type(val)}")
-
- if batch_size is None:
- batch_size = len(val)
- else:
- assert len(val) == batch_size
-
- if batch_size is None:
- batch_size = []
- else:
- batch_size = [batch_size]
-
- return TensorDict(tensors_batch, batch_size=batch_size)
-
- def fit(self):
- """
- The training loop of PPO.
- The driver process only need to call the compute functions of the worker group through RPC
- to construct the PPO dataflow.
- The light-weight advantage computation is done on the driver process.
- """
- from omegaconf import OmegaConf
-
- from verl.utils.tracking import Tracking
-
- logger = Tracking(
- project_name=self.config.trainer.project_name,
- experiment_name=self.config.trainer.experiment_name,
- default_backend=self.config.trainer.logger,
- config=OmegaConf.to_container(self.config, resolve=True),
- )
-
- self.global_steps = 0
-
- # load checkpoint and update weights before doing anything
- self._load_checkpoint()
- self.checkpoint_manager.update_weights()
-
- # perform validation before training
- # currently, we only support validation using the reward_function.
- if self.config.trainer.get("val_before_train", True):
- val_metrics = self._validate()
- assert val_metrics, f"{val_metrics=}"
- pprint(f"Initial validation metrics: {val_metrics}")
- logger.log(data=val_metrics, step=self.global_steps)
- if self.config.trainer.get("val_only", False):
- return
-
- if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
- rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
- rollout_skip.wrap_generate_sequences()
-
- # add tqdm
- progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
-
- # we start from step 1
- self.global_steps += 1
- last_val_metrics = None
- self.max_steps_duration = 0
-
- prev_step_profile = False
- curr_step_profile = (
- self.global_steps in self.config.global_profiler.steps
- if self.config.global_profiler.steps is not None
- else False
- )
- next_step_profile = False
-
- for epoch in range(self.config.trainer.total_epochs):
- for batch_dict in self.train_dataloader:
- metrics = {}
- timing_raw = {}
- base_get_meta_kwargs = dict(
- batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
- partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
- )
-
- with marked_timer("start_profile", timing_raw):
- self._start_profiling(
- not prev_step_profile and curr_step_profile
- if self.config.global_profiler.profile_continuous_steps
- else curr_step_profile
- )
-
- # add uid to batch
- batch_dict["uid"] = np.array(
- [str(uuid.uuid4()) for _ in range(len(batch_dict["raw_prompt"]))], dtype=object
- )
- # When n > 1, repeat input data before putting to data system, simulating DataProto repeat.
- repeated_batch_dict = self.repeat_dict(
- batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
- )
- batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
- gen_meta = self.tq_client.put(data=batch, partition_id=f"train_{self.global_steps - 1}")
-
- # pass global_steps to trace
- gen_meta.set_extra_info("global_steps", self.global_steps)
-
- is_last_step = self.global_steps >= self.total_training_steps
-
- with marked_timer("step", timing_raw):
- # generate a batch
- with marked_timer("gen", timing_raw, color="red"):
- if not self.async_rollout_mode:
- gen_output_meta = self.actor_rollout_wg.generate_sequences(gen_meta)
- else:
- gen_output_meta = self.async_rollout_manager.generate_sequences(gen_meta)
- self.checkpoint_manager.sleep_replicas()
- timing_raw.update(gen_output_meta.extra_info["timing"])
- gen_output_meta.extra_info.pop("timing", None)
-
- # TODO (TQ): support transfer queue
- # if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
- # if self.reward_fn is None:
- # raise ValueError("A reward_fn is required for REMAX advantage estimation.")
- #
- # with marked_timer("gen_max", timing_raw, color="purple"):
- # gen_baseline_meta = deepcopy(gen_meta)
- # gen_baseline_meta.extra_info["do_sample"] = False
- # if not self.async_rollout_mode:
- # gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_meta)
- # else:
- # gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_meta)
- # batch = batch.union(gen_baseline_output)
- # reward_baseline_tensor = self.reward_fn(batch)
- # reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
- #
- # batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
- #
- # batch.batch["reward_baselines"] = reward_baseline_tensor
- #
- # del gen_baseline_batch, gen_baseline_output
-
- batch_meta: BatchMeta = gen_meta.union(gen_output_meta)
-
- if "response_mask" not in batch_meta.field_names:
- response_mask_meta = self.tq_client.get_meta(
- data_fields=["responses", "attention_mask"],
- task_name="compute_response_mask",
- **base_get_meta_kwargs,
- )
- response_mask_output_meta = compute_response_mask(response_mask_meta, self.tq_client)
- batch_meta = batch_meta.union(response_mask_output_meta)
-
- # Balance the number of valid tokens across DP ranks.
- # NOTE: This usually changes the order of data in the `batch`,
- # which won't affect the advantage calculation (since it's based on uid),
- # but might affect the loss calculation (due to the change of mini-batching).
- # TODO: Decouple the DP balancing and mini-batching.
-
- attention_mask_meta = batch_meta.select_fields(["attention_mask"])
- balanced_idx = None
- if self.config.trainer.balance_batch:
- balanced_idx = self._balance_batch(attention_mask_meta, self.tq_client, metrics=metrics)
- batch_meta.reorder(balanced_idx)
-
- # compute global_valid tokens
- data = self.tq_client.get_data(attention_mask_meta)
- batch_meta.extra_info["global_token_num"] = torch.sum(data["attention_mask"], dim=-1).tolist()
-
- with marked_timer("reward", timing_raw, color="yellow"):
- # compute reward model score
- if self.use_rm and "rm_scores" not in batch_meta.field_names:
- reward_meta = self.rm_wg.compute_rm_score(batch_meta)
- batch_meta = batch_meta.union(reward_meta)
-
- compute_reward_fields = [
- "responses",
- "prompts",
- "attention_mask",
- "reward_model",
- "data_source",
- ]
- if "rm_scores" in batch_meta.field_names:
- compute_reward_fields.extend(
- ["rm_scores", *set(batch_meta.extra_info["reward_extra_keys"])]
- )
-
- reward_tensor, reward_extra_infos_dict = compute_reward_decorated(batch_meta)
-
- compute_reward_meta = batch_meta.select_fields(compute_reward_fields)
- batch_meta = batch_meta.union(compute_reward_meta)
-
- # recompute old_log_probs
- with marked_timer("old_log_prob", timing_raw, color="blue"):
- old_log_prob_meta_fields = [
- "input_ids",
- "attention_mask",
- "position_ids",
- "prompts",
- "responses",
- "response_mask",
- "data_source",
- "reward_model",
- "extra_info",
- "uid",
- "index",
- "tools_kwargs",
- "interaction_kwargs",
- "ability",
- ]
- old_log_prob_meta = batch_meta.select_fields(old_log_prob_meta_fields)
- old_log_prob_output_meta = self.actor_rollout_wg.compute_log_prob(old_log_prob_meta)
- batch_meta = batch_meta.union(old_log_prob_output_meta)
-
- data = self.tq_client.get_data(old_log_prob_output_meta)
- entropys = data["entropys"]
- response_masks = data["response_mask"]
- actor_config = self.config.actor_rollout_ref.actor
- entropy_agg = agg_loss(
- loss_mat=entropys,
- loss_mask=response_masks,
- loss_agg_mode=actor_config.loss_agg_mode,
- loss_scale_factor=actor_config.loss_scale_factor,
- )
- old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
- metrics.update(old_log_prob_metrics)
-
- if "rollout_log_probs" in batch_meta.field_names:
- # TODO: we may want to add diff of probs too.
- calculate_debug_metrics_fields = ["rollout_log_probs", "old_log_probs", "responses"]
-
- if "response_mask" in batch_meta.field_names:
- calculate_debug_metrics_fields.append("response_mask")
- if "attention_mask" in batch_meta.field_names:
- calculate_debug_metrics_fields.append("attention_mask")
-
- calculate_debug_metrics_meta = batch_meta.select_fields(calculate_debug_metrics_fields)
- metrics.update(calculate_debug_metrics_decorated(calculate_debug_metrics_meta))
-
- if self.use_reference_policy:
- # compute reference log_prob
- ref_log_prob_fields = [
- "input_ids",
- "attention_mask",
- "position_ids",
- "prompts",
- "responses",
- "response_mask",
- "old_log_probs",
- "data_source",
- "reward_model",
- "extra_info",
- "uid",
- "index",
- "tools_kwargs",
- "interaction_kwargs",
- "ability",
- ]
- ref_log_prob_meta = batch_meta.select_fields(ref_log_prob_fields)
-
- with marked_timer("ref", timing_raw, color="olive"):
- if not self.ref_in_actor:
- ref_log_prob_output_meta = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_meta)
- else:
- ref_log_prob_output_meta = self.actor_rollout_wg.compute_ref_log_prob(ref_log_prob_meta)
- batch_meta = batch_meta.union(ref_log_prob_output_meta)
-
- # compute values
- if self.use_critic:
- with marked_timer("values", timing_raw, color="cyan"):
- values_meta = self.critic_wg.compute_values(batch_meta)
- batch_meta = batch_meta.union(values_meta)
-
- with marked_timer("adv", timing_raw, color="brown"):
- # we combine with rule-based rm
- reward_extra_infos_dict: dict[str, list]
- reward_td = TensorDict({"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0))
- batch_meta = self.tq_client.put(data=reward_td, metadata=batch_meta)
-
- if reward_extra_infos_dict:
- reward_extra_infos_dict_new = {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
- reward_extra_infos_td = self.dict_to_tensordict(reward_extra_infos_dict_new)
- batch_meta = self.tq_client.put(data=reward_extra_infos_td, metadata=batch_meta)
-
- # compute rewards. apply_kl_penalty if available
- if self.config.algorithm.use_kl_in_reward:
- apply_kl_penalty_fields = [
- "response_mask",
- "token_level_scores",
- "old_log_probs",
- "ref_log_prob",
- ]
-
- apply_kl_penalty_meta = batch_meta.select_fields(apply_kl_penalty_fields)
-
- token_level_rewards, kl_metrics = apply_kl_penalty(
- apply_kl_penalty_meta,
- kl_ctrl=self.kl_ctrl_in_reward,
- kl_penalty=self.config.algorithm.kl_penalty,
- )
- token_level_rewards_td = TensorDict(
- {"token_level_rewards": token_level_rewards}, batch_size=token_level_rewards.size(0)
- )
- apply_kl_penalty_meta = self.tq_client.put(
- data=token_level_rewards_td, metadata=apply_kl_penalty_meta
- )
-
- metrics.update(kl_metrics)
- batch_meta = batch_meta.union(apply_kl_penalty_meta)
- else:
- token_level_scores_meta = batch_meta.select_fields(["token_level_scores"])
-
- data = self.tq_client.get_data(token_level_scores_meta)
- token_level_rewards_td = TensorDict(
- {"token_level_rewards": data["token_level_scores"]},
- batch_size=data["token_level_scores"].size(0),
- )
- token_level_scores_meta = self.tq_client.put(
- data=token_level_rewards_td, metadata=token_level_scores_meta
- )
- batch_meta = batch_meta.union(token_level_scores_meta)
-
- # compute advantages, executed on the driver process
-
- norm_adv_by_std_in_grpo = self.config.algorithm.get(
- "norm_adv_by_std_in_grpo", True
- ) # GRPO adv normalization factor
-
- assert "response_mask" in batch_meta.field_names, (
- f"`response_mask` must be in batch_meta {batch_meta.field_names} for advantage computation"
- )
- compute_advantage_fields = [
- "response_mask",
- "token_level_rewards",
- ]
- if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
- compute_advantage_fields.append("values")
- elif self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO:
- compute_advantage_fields.append("uid")
- else:
- if "uid" in batch_meta.field_names:
- compute_advantage_fields.append("uid")
- if "reward_baselines" in batch_meta.field_names:
- compute_advantage_fields.append("reward_baselines")
-
- compute_advantage_meta = batch_meta.select_fields(compute_advantage_fields)
-
- advantages, returns = compute_advantage(
- compute_advantage_meta,
- adv_estimator=self.config.algorithm.adv_estimator,
- gamma=self.config.algorithm.gamma,
- lam=self.config.algorithm.lam,
- num_repeat=self.config.actor_rollout_ref.rollout.n,
- norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
- config=self.config.algorithm,
- )
-
- advantages_td = TensorDict(
- {"advantages": advantages, "returns": returns}, batch_size=advantages.size(0)
- )
- compute_advantage_meta = self.tq_client.put(data=advantages_td, metadata=compute_advantage_meta)
- batch_meta = batch_meta.union(compute_advantage_meta)
-
- # update critic
- if self.use_critic:
- with marked_timer("update_critic", timing_raw, color="pink"):
- critic_output_meta = self.critic_wg.update_critic(batch_meta)
- batch_meta = batch_meta.union(critic_output_meta)
- critic_output_metrics = reduce_metrics(critic_output_meta.extra_info["metrics"])
- metrics.update(critic_output_metrics)
-
- # implement critic warmup
- if self.config.trainer.critic_warmup <= self.global_steps:
- # update actor
- with marked_timer("update_actor", timing_raw, color="red"):
- batch_meta.extra_info["multi_turn"] = (
- self.config.actor_rollout_ref.rollout.multi_turn.enable
- )
-
- update_actor_fields = [
- "input_ids",
- "attention_mask",
- "position_ids",
- "prompts",
- "responses",
- "response_mask",
- "old_log_probs",
- "ref_log_prob",
- "advantages",
- "returns",
- "token_level_rewards",
- "token_level_scores",
- "data_source",
- "reward_model",
- "extra_info",
- "uid",
- "index",
- "tools_kwargs",
- "interaction_kwargs",
- "ability",
- ]
- update_actor_meta = batch_meta.select_fields(update_actor_fields)
-
- update_actor_meta.set_extra_info(
- "global_token_num", batch_meta.get_extra_info("global_token_num")
- )
- update_actor_meta.set_extra_info("temperature", batch_meta.get_extra_info("temperature"))
-
- actor_output_meta = self.actor_rollout_wg.update_actor(update_actor_meta)
- batch_meta = batch_meta.union(actor_output_meta)
-
- # update weights from trainer to rollout
- with marked_timer("update_weights", timing_raw, color="red"):
- self.checkpoint_manager.update_weights()
-
- actor_output_metrics = reduce_metrics(actor_output_meta.extra_info["metrics"])
- metrics.update(actor_output_metrics)
-
- # Log rollout generations if enabled
- rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
- if rollout_data_dir:
- log_rollout_fields = ["prompts", "responses", "token_level_scores", "reward_model"]
- if "request_id" in batch_meta.field_names:
- log_rollout_fields.append("request_id")
- log_rollout_meta = batch_meta.select_fields(log_rollout_fields)
- self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir)
-
- # TODO: validate
- if self.config.trainer.test_freq > 0 and (
- is_last_step or self.global_steps % self.config.trainer.test_freq == 0
- ):
- with marked_timer("testing", timing_raw, color="green"):
- val_metrics: dict = self._validate()
- if is_last_step:
- last_val_metrics = val_metrics
- metrics.update(val_metrics)
-
- # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
- esi_close_to_expiration = should_save_ckpt_esi(
- max_steps_duration=self.max_steps_duration,
- redundant_time=self.config.trainer.esi_redundant_time,
- )
- # Check if the conditions for saving a checkpoint are met.
- # The conditions include a mandatory condition (1) and
- # one of the following optional conditions (2/3/4):
- # 1. The save frequency is set to a positive value.
- # 2. It's the last training step.
- # 3. The current step number is a multiple of the save frequency.
- # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
- if self.config.trainer.save_freq > 0 and (
- is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
- ):
- if esi_close_to_expiration:
- print("Force saving checkpoint: ESI instance expiration approaching.")
- with marked_timer("save_checkpoint", timing_raw, color="green"):
- self._save_checkpoint()
-
- with marked_timer("stop_profile", timing_raw):
- next_step_profile = (
- self.global_steps + 1 in self.config.global_profiler.steps
- if self.config.global_profiler.steps is not None
- else False
- )
- self._stop_profiling(
- curr_step_profile and not next_step_profile
- if self.config.global_profiler.profile_continuous_steps
- else curr_step_profile
- )
- prev_step_profile = curr_step_profile
- curr_step_profile = next_step_profile
-
- steps_duration = timing_raw["step"]
- self.max_steps_duration = max(self.max_steps_duration, steps_duration)
-
- # training metrics
- metrics.update(
- {
- "training/global_step": self.global_steps,
- "training/epoch": epoch,
- }
- )
- # collect metrics
- compute_data_metrics_fields = [
- "token_level_rewards",
- "token_level_scores",
- "advantages",
- "returns",
- "responses",
- "attention_mask",
- "response_mask",
- ]
- if "__num_turns__" in batch_meta.field_names:
- compute_data_metrics_fields.append("__num_turns__")
- if "tool_call_counts" in batch_meta.field_names:
- compute_data_metrics_fields.append("tool_call_counts")
- compute_data_metrics_meta = batch_meta.select_fields(compute_data_metrics_fields)
- compute_data_metrics_meta.reorder(balanced_idx)
- metrics.update(
- compute_data_metrics_decorated(batch=compute_data_metrics_meta, use_critic=self.use_critic)
- )
-
- compute_timing_metrics_fields = ["responses", "attention_mask"]
- compute_timing_metrics_meta = batch_meta.select_fields(compute_timing_metrics_fields)
- compute_timing_metrics_meta.reorder(balanced_idx)
- metrics.update(
- compute_timing_metrics_decorated(batch=compute_timing_metrics_meta, timing_raw=timing_raw)
- )
-
- compute_throughout_metrics_meta = BatchMeta(
- samples=[],
- extra_info={"global_token_num": batch_meta.get_extra_info("global_token_num")},
- )
- # TODO: implement actual tflpo and theoretical tflpo
- n_gpus = self.resource_pool_manager.get_n_gpus()
- metrics.update(
- compute_throughout_metrics_decorated(
- batch=compute_throughout_metrics_meta, timing_raw=timing_raw, n_gpus=n_gpus
- )
- )
-
- # this is experimental and may be changed/removed in the future in favor of a general-purpose one
- if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
- # TODO (TQ) :support transfer queue
- self.train_dataloader.sampler.update(batch=batch)
-
- self.tq_client.clear_samples(batch_meta)
- # TODO: make a canonical logger that supports various backend
- logger.log(data=metrics, step=self.global_steps)
-
- progress_bar.update(1)
- self.global_steps += 1
-
- if (
- hasattr(self.config.actor_rollout_ref.actor, "profiler")
- and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
- ):
- self.actor_rollout_wg.dump_memory_snapshot(
- tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
- )
-
- if is_last_step:
- pprint(f"Final validation metrics: {last_val_metrics}")
- progress_bar.close()
- return
-
- # this is experimental and may be changed/removed in the future
- # in favor of a general-purpose data buffer pool
- if hasattr(self.train_dataset, "on_batch_end"):
- # The dataset may be changed after each training batch
- # TODO (TQ): support transfer queue
- self.train_dataset.on_batch_end(batch=batch)
diff --git a/verl/experimental/transfer_queue/run_qwen3-8b_transferqueue.sh b/verl/experimental/transfer_queue/run_qwen3-8b_transferqueue.sh
deleted file mode 100644
index bd6d09e32d7..00000000000
--- a/verl/experimental/transfer_queue/run_qwen3-8b_transferqueue.sh
+++ /dev/null
@@ -1,70 +0,0 @@
-set -x
-
-MODEL_PATH="/workspace/models/Qwen3-8B"
-TRAIN_FILE="/workspace/datasets/preprocessed/gsm8k/train.parquet"
-TEST_FILE="/workspace/datasets/preprocessed/gsm8k/test.parquet"
-
-log_dir="./logs"
-mkdir -p ${log_dir}
-timestamp=$(date +"%Y%m%d%H%M%S")
-log_file="${log_dir}/qwen3-8b_tq_${timestamp}.log"
-
-# You may try to enable zero-copy serialization for TransferQueue when using SimpleStorageUnit backend.
-export TQ_ZERO_COPY_SERIALIZATION=False
-
-rollout_mode="async"
-rollout_name="vllm" # sglang or vllm
-if [ "$rollout_mode" = "async" ]; then
- export VLLM_USE_V1=1
- return_raw_chat="True"
-fi
-
-# You may also refer to tests/special_e2e/run_transferqueue.sh for more demo scripts
-
-python3 -m verl.experimental.transfer_queue.main_ppo \
- --config-name='transfer_queue_ppo_trainer' \
- algorithm.adv_estimator=grpo \
- data.train_files=${TRAIN_FILE} \
- data.val_files=${TEST_FILE} \
- data.return_raw_chat=$return_raw_chat \
- data.train_batch_size=128 \
- data.max_prompt_length=2048 \
- data.max_response_length=8192 \
- data.filter_overlong_prompts_workers=128 \
- data.filter_overlong_prompts=True \
- data.truncation='error' \
- actor_rollout_ref.model.path=${MODEL_PATH} \
- actor_rollout_ref.actor.optim.lr=1e-6 \
- actor_rollout_ref.model.use_remove_padding=True \
- actor_rollout_ref.actor.ppo_mini_batch_size=32 \
- actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
- actor_rollout_ref.actor.use_kl_loss=True \
- actor_rollout_ref.actor.kl_loss_coef=0.001 \
- actor_rollout_ref.actor.kl_loss_type=low_var_kl \
- actor_rollout_ref.actor.entropy_coeff=0 \
- actor_rollout_ref.model.enable_gradient_checkpointing=True \
- actor_rollout_ref.actor.fsdp_config.param_offload=True \
- actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
- actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
- actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
- actor_rollout_ref.rollout.max_num_batched_tokens=10240 \
- actor_rollout_ref.rollout.name=$rollout_name \
- actor_rollout_ref.rollout.mode=$rollout_mode \
- actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
- actor_rollout_ref.rollout.n=5 \
- actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
- actor_rollout_ref.ref.fsdp_config.param_offload=True \
- algorithm.use_kl_in_reward=False \
- trainer.critic_warmup=0 \
- trainer.logger=console \
- trainer.project_name='verl_grpo_example_gsm8k' \
- trainer.experiment_name='qwen3_8b_function_rm' \
- trainer.n_gpus_per_node=8 \
- trainer.nnodes=1 \
- trainer.save_freq=-1 \
- trainer.test_freq=1000 \
- trainer.total_epochs=15 \
- trainer.total_training_steps=2 \
- trainer.val_before_train=False \
- 2>&1 | tee "$log_file"
-echo "Finished, log is saved in: $log_file"
\ No newline at end of file
diff --git a/verl/experimental/vla/rob_ray_trainer.py b/verl/experimental/vla/rob_ray_trainer.py
index 619c822c75d..519b7dab4b5 100644
--- a/verl/experimental/vla/rob_ray_trainer.py
+++ b/verl/experimental/vla/rob_ray_trainer.py
@@ -600,13 +600,10 @@ def _validate(self):
# pad to be divisible by dp_size
size_divisor = self.config.env.train.num_envs * self.config.env.rollout.pipeline_stage_num
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
- if not self.async_rollout_mode:
- test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
- else:
- reset_future = self._reset_envs(test_gen_batch_padded)
- test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(
- test_gen_batch_padded, reset_future
- )
+ reset_future = self._reset_envs(test_gen_batch_padded)
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(
+ test_gen_batch_padded, reset_future
+ )
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
diff --git a/verl/experimental/vla/sac/sac_ray_trainer.py b/verl/experimental/vla/sac/sac_ray_trainer.py
index cbdc1d597a2..310a9dc7686 100644
--- a/verl/experimental/vla/sac/sac_ray_trainer.py
+++ b/verl/experimental/vla/sac/sac_ray_trainer.py
@@ -528,13 +528,10 @@ def _validate(self):
# pad to be divisible by dp_size
size_divisor = self.config.env.train.num_envs * self.config.env.rollout.pipeline_stage_num
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
- if not self.async_rollout_mode:
- test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
- else:
- reset_future = self._reset_envs(test_gen_batch_padded)
- test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(
- test_gen_batch_padded, reset_future
- )
+ reset_future = self._reset_envs(test_gen_batch_padded)
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(
+ test_gen_batch_padded, reset_future
+ )
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
diff --git a/verl/models/mcore/mtp_patch.py b/verl/models/mcore/mtp_patch.py
index 117b6e3f28c..fadf5b7bd52 100644
--- a/verl/models/mcore/mtp_patch.py
+++ b/verl/models/mcore/mtp_patch.py
@@ -20,11 +20,7 @@
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_model import GPTModel
-from megatron.core.transformer.multi_token_prediction import (
- MTPLossAutoScaler,
- MTPLossLoggingHelper,
- roll_tensor,
-)
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor
try:
from megatron.core.utils import unwrap_model
@@ -78,19 +74,45 @@ def _megatron_gptmodel_postprocess(
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
+ **kwargs,
):
- """Postprocesses decoder hidden states to generate logits or compute loss.
+ """Compatibility patch for GPTModel._postprocess.
- Applies Multi-Token Prediction if enabled, generates output logits through
- the output layer, and computes language model loss when labels are provided.
+ For inference (`labels is None`), delegate to the upstream implementation to stay
+ aligned with Megatron-Core updates.
+
+ For training (`labels is not None`), keep VERL's MTP behavior and always return
+ logits (instead of CE loss) so PPO paths can compute custom losses from logits.
"""
+ # Keep inference path aligned with whatever upstream Megatron currently expects.
+ if labels is None:
+ return self._postprocess_backup(
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ labels=labels,
+ rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_cos=rotary_pos_cos,
+ rotary_pos_sin=rotary_pos_sin,
+ mtp_in_postprocess=mtp_in_postprocess,
+ loss_mask=loss_mask,
+ decoder_input=decoder_input,
+ attention_mask=attention_mask,
+ inference_params=inference_params,
+ packed_seq_params=packed_seq_params,
+ sequence_len_offset=sequence_len_offset,
+ runtime_gather_output=runtime_gather_output,
+ extra_block_kwargs=extra_block_kwargs,
+ inference_context=inference_context,
+ **kwargs,
+ )
- # logits and loss
+ # Training path: keep logits for external loss computation.
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
- if mtp_in_postprocess and labels is not None:
+ if mtp_in_postprocess:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
@@ -109,60 +131,85 @@ def _megatron_gptmodel_postprocess(
if not self.post_process:
return hidden_states
- # Skip when mtp_num_layers is None or 0
- if self.config.mtp_num_layers and labels is not None:
- mtp_labels = labels.clone()
-
- hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
- hidden_states = hidden_states_list[0]
- if loss_mask is None:
- # if loss_mask is not provided, use all ones as loss_mask
- loss_mask = torch.ones_like(mtp_labels)
- for mtp_layer_number in range(self.config.mtp_num_layers):
- # Calc loss for the current Multi-Token Prediction (MTP) layers.
- mtp_labels, _ = roll_tensor(
- mtp_labels,
- shifts=-1,
- dims=-1,
- cp_group=self.cp_group,
+ # Skip when mtp_num_layers is None or 0.
+ if self.config.mtp_num_layers:
+ cp_group = None
+ if getattr(self, "pg_collection", None) is not None:
+ cp_group = self.pg_collection.cp
+ elif hasattr(self, "cp_group"):
+ cp_group = self.cp_group
+
+ # Prefer upstream helper when available (newer Megatron-LM).
+ try:
+ from megatron.core.transformer.multi_token_prediction import process_mtp_loss
+
+ hidden_states = process_mtp_loss(
+ hidden_states=hidden_states,
+ labels=labels,
+ loss_mask=loss_mask,
+ output_layer=self.output_layer,
+ output_weight=output_weight,
+ runtime_gather_output=runtime_gather_output,
+ is_training=self.training,
+ compute_language_model_loss=self.compute_language_model_loss,
+ config=self.config,
+ cp_group=cp_group,
packed_seq_params=packed_seq_params,
)
- loss_mask, num_tokens = roll_tensor(
- loss_mask,
- shifts=-1,
- dims=-1,
- cp_group=self.cp_group,
- packed_seq_params=packed_seq_params,
- )
-
- # Compute mtp loss without storing logits to save memory.
- mtp_loss = self.compute_output_layer_and_language_model_loss(
- hidden_states_list[mtp_layer_number + 1],
- labels=mtp_labels,
- weight=self.shared_embedding_or_output_weight(),
- sequence_parallel_enabled=self.output_layer.sequence_parallel,
- column_parallel_linear=self.output_layer,
- col_linear_kwargs={
- "weight": output_weight,
- "runtime_gather_output": runtime_gather_output,
- },
- )
+ except (ImportError, AttributeError, TypeError):
+ # Fallback for older Megatron-LM versions without process_mtp_loss API.
+ mtp_labels = labels.clone()
+
+ hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+ hidden_states = hidden_states_list[0]
+ if loss_mask is None:
+ # if loss_mask is not provided, use all ones as loss_mask
+ loss_mask = torch.ones_like(mtp_labels)
+ for mtp_layer_number in range(self.config.mtp_num_layers):
+ # Calc loss for the current Multi-Token Prediction (MTP) layers.
+ mtp_labels, _ = roll_tensor(
+ mtp_labels,
+ shifts=-1,
+ dims=-1,
+ cp_group=self.cp_group,
+ packed_seq_params=packed_seq_params,
+ )
+ loss_mask, num_tokens = roll_tensor(
+ loss_mask,
+ shifts=-1,
+ dims=-1,
+ cp_group=self.cp_group,
+ packed_seq_params=packed_seq_params,
+ )
- mtp_loss = loss_mask * mtp_loss
- if self.training:
- # TODO(shifangx): remove the use of parallel_state here
- # after moving loss logging to loss_func in pretrain_gpt.py
- MTPLossLoggingHelper.save_loss_to_tracker(
- torch.sum(mtp_loss) / num_tokens,
- mtp_layer_number,
- self.config.mtp_num_layers,
- avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+ # Compute mtp loss without storing logits to save memory.
+ mtp_loss = self.compute_output_layer_and_language_model_loss(
+ hidden_states_list[mtp_layer_number + 1],
+ labels=mtp_labels,
+ weight=self.shared_embedding_or_output_weight(),
+ sequence_parallel_enabled=self.output_layer.sequence_parallel,
+ column_parallel_linear=self.output_layer,
+ col_linear_kwargs={
+ "weight": output_weight,
+ "runtime_gather_output": runtime_gather_output,
+ },
)
- mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
- if self.config.calculate_per_token_loss:
- hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
- else:
- hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+
+ mtp_loss = loss_mask * mtp_loss
+ if self.training:
+ # TODO(shifangx): remove the use of parallel_state here
+ # after moving loss logging to loss_func in pretrain_gpt.py
+ MTPLossLoggingHelper.save_loss_to_tracker(
+ torch.sum(mtp_loss) / num_tokens,
+ mtp_layer_number,
+ self.config.mtp_num_layers,
+ avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+ )
+ mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+ if self.config.calculate_per_token_loss:
+ hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+ else:
+ hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
# [s b h] => [b s h]
diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py
index aefb798aa0b..dc35310c894 100644
--- a/verl/models/mcore/util.py
+++ b/verl/models/mcore/util.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import math
+import os
import torch
from megatron.core import parallel_state as mpu
@@ -21,6 +23,9 @@
from verl.utils.model import CausalLMOutputForPPO
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
def preprocess_packed_seqs(
input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False
@@ -333,6 +338,19 @@ def preprocess_thd_no_padding(
start_idx = cu_seqlens_padded_cpu[i] // cp_size
# split to 2 chunks
d = input_ids[i]
+ # If the number of elements in `d` is smaller than the required
+ # alignment size, pad the tensor with zeros so that its total
+ # length matches `align_size`. This ensures size alignment for
+ # downstream operations (e.g., communication or memory alignment).
+ if d.numel() < align_size:
+ original_size = d.numel()
+ pad = torch.zeros(align_size - d.numel(), dtype=d.dtype, device=d.device)
+ d = torch.cat([d, pad], dim=0)
+ logger.warning_once(
+ f"Padding tensor for context parallel alignment, original_size={original_size}, "
+ f"align_size={align_size}"
+ )
+
input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
]
diff --git a/verl/protocol.py b/verl/protocol.py
index 27a1f6a1f94..820b9bd3462 100644
--- a/verl/protocol.py
+++ b/verl/protocol.py
@@ -36,7 +36,7 @@
from torch.utils.data import DataLoader
from verl.utils.device import get_device_id, get_torch_device
-from verl.utils.py_functional import union_two_dict
+from verl.utils.py_functional import list_of_dict_to_dict_of_list, union_two_dict
from verl.utils.torch_functional import allgather_dict_tensors
__all__ = ["DataProto", "union_tensor_dict"]
@@ -198,18 +198,6 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str
return tensor_dict1
-def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
- if len(list_of_dict) == 0:
- return {}
- keys = list_of_dict[0].keys()
- output = {key: [] for key in keys}
- for data in list_of_dict:
- for key, item in data.items():
- assert key in output
- output[key].append(item)
- return output
-
-
def fold_batch_dim(data: "DataProto", new_batch_size):
"""
Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py
index 540c4e00552..c8438309879 100644
--- a/verl/single_controller/base/decorator.py
+++ b/verl/single_controller/base/decorator.py
@@ -20,7 +20,6 @@
from verl.protocol import DataProtoFuture, _padding_size_key
from verl.utils.py_functional import DynamicEnum
from verl.utils.tensordict_utils import chunk_tensordict, concat_tensordict, contiguous
-from verl.utils.transferqueue_utils import BatchMeta
# here we add a magic number of avoid user-defined function already have this attribute
MAGIC_ATTR = "attrs_3141562937"
@@ -80,7 +79,7 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
splitted_args = []
for arg in args:
- assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta | TensorDict)
+ assert isinstance(arg, DataProto | DataProtoFuture | TensorDict)
if isinstance(arg, TensorDict):
chunked_arg = chunk_tensordict(arg, chunks)
chunked_arg = _consolidate_tuple_td(chunked_arg)
@@ -91,7 +90,7 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
splitted_kwargs = {}
for key, val in kwargs.items():
- assert isinstance(val, DataProto | DataProtoFuture | BatchMeta | TensorDict)
+ assert isinstance(val, DataProto | DataProtoFuture | TensorDict)
if isinstance(val, TensorDict):
chunked_kwarg = chunk_tensordict(val, chunks)
chunked_kwarg = _consolidate_tuple_td(chunked_kwarg)
@@ -165,8 +164,6 @@ def _concat_data_proto_or_future(output: list):
return DataProto.concat(output)
elif isinstance(o, ray.ObjectRef):
return DataProtoFuture.concat(output)
- elif isinstance(o, BatchMeta):
- return BatchMeta.concat(output)
elif isinstance(o, TensorDict):
return concat_tensordict(output)
else:
@@ -288,8 +285,8 @@ def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output)
from verl.protocol import DataProto
for o in output:
- assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), (
- f"expecting {o} to be DataProto | ray.ObjectRef | BatchMeta | TensorDict, but got {type(o)}"
+ assert isinstance(o, DataProto | ray.ObjectRef | TensorDict), (
+ f"expecting {o} to be DataProto | ray.ObjectRef | TensorDict, but got {type(o)}"
)
return _concat_data_proto_or_future(output)
@@ -447,14 +444,11 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki
A decorator that wraps the original function with distributed execution
configuration.
"""
- from verl.utils.transferqueue_utils import tqbridge
_check_dispatch_mode(dispatch_mode=dispatch_mode)
_check_execute_mode(execute_mode=execute_mode)
def decorator(func):
- func = tqbridge(dispatch_mode=dispatch_mode)(func)
-
@wraps(func)
def inner(*args, **kwargs):
if materialize_futures:
diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py
index cffaf5d30a0..6fc955cb049 100644
--- a/verl/single_controller/base/worker.py
+++ b/verl/single_controller/base/worker.py
@@ -165,15 +165,6 @@ def set_dispatch_collect(self, mesh_name: str, dispatch_dp_rank: dict[str, int],
for is_collect in collect_dp_rank.values():
self.__collect_dp_rank[mesh_name] = is_collect
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
- def create_transferqueue_client(self, config):
- from verl.utils.transferqueue_utils import create_transferqueue_client
-
- create_transferqueue_client(
- client_id=f"worker_{self.rank}",
- config=config.transfer_queue,
- )
-
@classmethod
def env_keys(cls):
"""The keys of the environment variables that are used to configure the Worker."""
diff --git a/verl/trainer/README.md b/verl/trainer/README.md
new file mode 100644
index 00000000000..d881d689b58
--- /dev/null
+++ b/verl/trainer/README.md
@@ -0,0 +1,16 @@
+# verl Main Entrypoints
+
+## SFT Trainer
+- sft_trainer.py: SFT trainer based on model engine, supporting various backends: fsdp, megatron, veomni, torchtitan. Launched by `torchrun` and run in multi-controller mode.
+- **[EXPERIMENTAL]** sft_trainer_ray.py: SFT trainer based on model engine with single-controller mode. Launched by ray with a driver process coordinating multiple worker processes.
+
+## RL Trainer
+|trainer|description|sync/async|trainer/rollout|partial rollout|
+|----|----|----|----|----|
+|main_ppo.py|rollout until a batch is completed, then train|synchronous|colocated|No|
+|TBD|[kimi-1.5](https://arxiv.org/pdf/2501.12599) style trainer: streaming rollout with capped length partial rollout|asynchronous|colocated|Yes|
+|TBD|[Areal](https://arxiv.org/pdf/2505.24298) style trainer: fully decoupled trainer and rollout with staleness control|asynchronous|disaggregated|Yes|
+
+## Inference and Evaluation
+- main_generation_server.py: Launch standalone servers and generate responses for a specified prompt dataset.
+- main_eval.py: Evaluate the performance of generated responses with reward function on a specified prompt dataset.
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index a6e80b62cd2..09391ec6af3 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -216,6 +216,8 @@ actor_rollout_ref:
_target_: verl.workers.config.RolloutConfig
name: ???
mode: async
+ nnodes: 0
+ n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8}
temperature: 1.0
top_k: -1
top_p: 1
@@ -290,6 +292,8 @@ actor_rollout_ref:
engine_kwargs: {}
trace:
_target_: verl.workers.config.TraceConfig
+ project_name: ${oc.select:trainer.project_name,null}
+ experiment_name: ${oc.select:trainer.experiment_name,null}
backend: null
token2text: false
max_samples_per_step_per_worker: null
@@ -324,6 +328,7 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ qat: ${oc.select:actor_rollout_ref.actor.qat,null}
layer_name_map:
qkv_layer_name: qkv
gate_proj_layer_name: gate_up
@@ -585,6 +590,10 @@ reward_model:
reward_loop_source: null
reward_loop_module_path: null
reward_loop_class_name: null
+ model:
+ path: null
+ external_lib: null
+ trust_remote_code: null
rollout:
name: null
dtype: null
diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
new file mode 100644
index 00000000000..b923da853ec
--- /dev/null
+++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
@@ -0,0 +1,677 @@
+# This reference configuration yaml is automatically generated via 'scripts/generate_trainer_config.sh'
+# in which it invokes 'python3 scripts/print_cfg.py --cfg job model_engine=torchtitan' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.
+# Do not modify this file directly.
+# The file is usually only for reference and never used.
+
+actor_rollout_ref:
+ actor:
+ optim:
+ _target_: verl.workers.config.TorchtitanOptimizerConfig
+ name: AdamW
+ lr: 1.0e-06
+ lr_warmup_steps_ratio: 0.0
+ total_training_steps: -1
+ weight_decay: 0.01
+ lr_warmup_steps: -1
+ betas:
+ - 0.9
+ - 0.999
+ clip_grad: 1.0
+ eps: 1.0e-08
+ decay_type: linear
+ min_lr_factor: 0.0
+ torchtitan:
+ _target_: verl.workers.config.TorchtitanEngineConfig
+ param_offload: false
+ optimizer_offload: false
+ wrap_policy:
+ min_num_params: 0
+ reshard_after_forward: default
+ forward_prefetch: false
+ use_orig_params: false
+ mixed_precision: false
+ use_torch_compile: true
+ entropy_from_logits_with_chunking: false
+ entropy_checkpointing: false
+ data_parallel_size: 1
+ data_parallel_replicate_size: 1
+ data_parallel_shard_size: 1
+ tensor_parallel_size: 1
+ expert_parallel_size: 1
+ pipeline_parallel_size: 1
+ context_parallel_size: 1
+ attn_type: flex
+ strategy: torchtitan
+ seed: 42
+ full_determinism: false
+ forward_only: false
+ dtype: bfloat16
+ _target_: verl.workers.config.TorchTitanActorConfig
+ rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
+ strategy: torchtitan
+ ppo_mini_batch_size: 256
+ ppo_micro_batch_size: null
+ ppo_micro_batch_size_per_gpu: null
+ use_dynamic_bsz: false
+ ppo_max_token_len_per_gpu: 16384
+ clip_ratio: 0.2
+ clip_ratio_low: 0.2
+ clip_ratio_high: 0.2
+ tau_pos: 1.0
+ tau_neg: 1.05
+ freeze_vision_tower: false
+ policy_loss:
+ _target_: verl.workers.config.PolicyLossConfig
+ loss_mode: vanilla
+ clip_cov_ratio: 0.0002
+ clip_cov_lb: 1.0
+ clip_cov_ub: 5.0
+ kl_cov_ratio: 0.0002
+ ppo_kl_coef: 0.1
+ clip_ratio_c: 3.0
+ loss_agg_mode: token-mean
+ loss_scale_factor: null
+ entropy_coeff: 0
+ calculate_entropy: false
+ use_kl_loss: false
+ use_prefix_grouper: false
+ use_torch_compile: true
+ kl_loss_coef: 0.001
+ kl_loss_type: low_var_kl
+ ppo_epochs: 1
+ shuffle: false
+ data_loader_seed: 42
+ checkpoint:
+ _target_: verl.trainer.config.CheckpointConfig
+ save_contents:
+ - model
+ - optimizer
+ - extra
+ load_contents: ${.save_contents}
+ async_save: false
+ mbridge_config: {}
+ use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
+ profiler:
+ _target_: verl.utils.profiler.ProfilerConfig
+ tool: ${oc.select:global_profiler.tool,null}
+ enable: false
+ all_ranks: false
+ ranks: []
+ save_path: ${oc.select:global_profiler.save_path,null}
+ tool_config:
+ nsys:
+ _target_: verl.utils.profiler.config.NsightToolConfig
+ discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
+ npu:
+ _target_: verl.utils.profiler.config.NPUToolConfig
+ contents: []
+ level: level0
+ analysis: true
+ discrete: false
+ torch:
+ _target_: verl.utils.profiler.config.TorchProfilerToolConfig
+ contents: []
+ discrete: false
+ torch_memory:
+ _target_: verl.utils.profiler.config.TorchMemoryToolConfig
+ trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
+ stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+ router_replay:
+ _target_: verl.workers.config.RouterReplayConfig
+ mode: disabled
+ record_file: null
+ replay_file: null
+ ref:
+ optim:
+ _target_: verl.workers.config.TorchtitanOptimizerConfig
+ name: AdamW
+ lr: 0.001
+ lr_warmup_steps_ratio: 0.0
+ total_training_steps: -1
+ weight_decay: 0.01
+ lr_warmup_steps: -1
+ betas:
+ - 0.9
+ - 0.999
+ clip_grad: 1.0
+ eps: 1.0e-08
+ decay_type: linear
+ min_lr_factor: 0.0
+ torchtitan:
+ _target_: verl.workers.config.TorchtitanEngineConfig
+ param_offload: false
+ optimizer_offload: false
+ wrap_policy:
+ min_num_params: 0
+ reshard_after_forward: default
+ forward_prefetch: false
+ use_orig_params: false
+ mixed_precision: false
+ use_torch_compile: true
+ entropy_from_logits_with_chunking: false
+ entropy_checkpointing: false
+ data_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_size,1}
+ data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_replicate_size,1}
+ data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_shard_size,1}
+ tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.tensor_parallel_size,1}
+ expert_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.expert_parallel_size,1}
+ pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.pipeline_parallel_size,1}
+ context_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.context_parallel_size,1}
+ attn_type: ${oc.select:actor_rollout_ref.actor.torchtitan.attn_type,flex}
+ strategy: torchtitan
+ seed: ${oc.select:actor_rollout_ref.actor.torchtitan.seed,42}
+ full_determinism: false
+ forward_only: true
+ dtype: bfloat16
+ rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
+ strategy: torchtitan
+ use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
+ log_prob_micro_batch_size: null
+ log_prob_micro_batch_size_per_gpu: null
+ log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
+ log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
+ profiler:
+ _target_: verl.utils.profiler.ProfilerConfig
+ tool: ${oc.select:global_profiler.tool,null}
+ enable: false
+ all_ranks: false
+ ranks: []
+ save_path: ${oc.select:global_profiler.save_path,null}
+ tool_config:
+ nsys:
+ _target_: verl.utils.profiler.config.NsightToolConfig
+ discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
+ npu:
+ _target_: verl.utils.profiler.config.NPUToolConfig
+ contents: []
+ level: level0
+ analysis: true
+ discrete: false
+ torch:
+ _target_: verl.utils.profiler.config.TorchProfilerToolConfig
+ contents: []
+ discrete: false
+ torch_memory:
+ _target_: verl.utils.profiler.config.TorchMemoryToolConfig
+ trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
+ stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+ router_replay:
+ _target_: verl.workers.config.RouterReplayConfig
+ mode: disabled
+ record_file: null
+ replay_file: null
+ _target_: verl.workers.config.TorchTitanActorConfig
+ rollout:
+ _target_: verl.workers.config.RolloutConfig
+ name: ???
+ mode: async
+ nnodes: 0
+ n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8}
+ temperature: 1.0
+ top_k: -1
+ top_p: 1
+ prompt_length: ${oc.select:data.max_prompt_length,512}
+ response_length: ${oc.select:data.max_response_length,512}
+ dtype: bfloat16
+ gpu_memory_utilization: 0.5
+ ignore_eos: false
+ enforce_eager: false
+ cudagraph_capture_sizes: null
+ free_cache_engine: true
+ tensor_model_parallel_size: 2
+ data_parallel_size: 1
+ expert_parallel_size: 1
+ pipeline_model_parallel_size: 1
+ max_num_batched_tokens: 8192
+ max_model_len: null
+ max_num_seqs: 1024
+ enable_chunked_prefill: true
+ enable_prefix_caching: true
+ logprobs_mode: processed_logprobs
+ scheduling_policy: fcfs
+ load_format: dummy
+ log_prob_micro_batch_size: null
+ log_prob_micro_batch_size_per_gpu: null
+ log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
+ log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
+ disable_log_stats: true
+ do_sample: true
+ 'n': 1
+ over_sample_rate: 0
+ multi_stage_wake_up: false
+ engine_kwargs:
+ vllm: {}
+ sglang: {}
+ trtllm: {}
+ val_kwargs:
+ _target_: verl.workers.config.SamplingConfig
+ top_k: -1
+ top_p: 1.0
+ temperature: 0
+ 'n': 1
+ do_sample: false
+ multi_turn:
+ _target_: verl.workers.config.MultiTurnConfig
+ enable: false
+ max_assistant_turns: null
+ tool_config_path: null
+ max_user_turns: null
+ max_parallel_calls: 1
+ max_tool_response_length: 256
+ tool_response_truncate_side: middle
+ interaction_config_path: null
+ use_inference_chat_template: false
+ tokenization_sanity_check_mode: strict
+ format: hermes
+ num_repeat_rollouts: null
+ calculate_log_probs: false
+ agent:
+ _target_: verl.workers.config.AgentLoopConfig
+ num_workers: 8
+ default_agent_loop: single_turn_agent
+ agent_loop_config_path: null
+ custom_async_server:
+ _target_: verl.workers.config.CustomAsyncServerConfig
+ path: null
+ name: null
+ checkpoint_engine:
+ _target_: verl.workers.config.CheckpointEngineConfig
+ backend: naive
+ update_weights_bucket_megabytes: 2048
+ engine_kwargs: {}
+ trace:
+ _target_: verl.workers.config.TraceConfig
+ project_name: ${oc.select:trainer.project_name,null}
+ experiment_name: ${oc.select:trainer.experiment_name,null}
+ backend: null
+ token2text: false
+ max_samples_per_step_per_worker: null
+ skip_rollout: false
+ skip_dump_dir: /tmp/rollout_dump
+ skip_tokenizer_init: true
+ enable_rollout_routing_replay: false
+ profiler:
+ _target_: verl.utils.profiler.ProfilerConfig
+ tool: ${oc.select:global_profiler.tool,null}
+ enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false}
+ all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
+ ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
+ save_path: ${oc.select:global_profiler.save_path,null}
+ tool_config:
+ npu:
+ _target_: verl.utils.profiler.config.NPUToolConfig
+ contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.contents,[]}
+ level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0}
+ analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false}
+ discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false}
+ torch:
+ _target_: verl.utils.profiler.config.TorchProfilerToolConfig
+ contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]}
+ discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false}
+ prometheus:
+ _target_: verl.workers.config.PrometheusConfig
+ enable: false
+ port: 9090
+ file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
+ served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
+ quantization: null
+ quantization_config_file: null
+ mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ qat: ${oc.select:actor_rollout_ref.actor.qat,null}
+ layered_summon: false
+ model:
+ _target_: verl.workers.config.HFModelConfig
+ path: ~/models/deepseek-llm-7b-chat
+ hf_config_path: null
+ tokenizer_path: null
+ use_shm: false
+ trust_remote_code: false
+ custom_chat_template: null
+ external_lib: null
+ override_config: {}
+ enable_gradient_checkpointing: true
+ enable_activation_offload: false
+ use_remove_padding: true
+ lora_rank: 0
+ lora_alpha: 16
+ target_modules: all-linear
+ exclude_modules: null
+ lora_adapter_path: null
+ use_liger: false
+ use_fused_kernels: false
+ fused_kernel_options:
+ impl_backend: torch
+ tiled_mlp:
+ enabled: false
+ num_shards: 4
+ mtp:
+ _target_: verl.workers.config.MtpConfig
+ enable: false
+ enable_train: false
+ enable_rollout: false
+ detach_encoder: false
+ mtp_loss_scaling_factor: 0.1
+ speculative_algorithm: EAGLE
+ speculative_num_steps: 3
+ speculative_eagle_topk: 1
+ speculative_num_draft_tokens: 4
+ method: mtp
+ num_speculative_tokens: 1
+ hybrid_engine: true
+ nccl_timeout: 600
+data:
+ tokenizer: null
+ use_shm: false
+ train_files: ~/data/rlhf/gsm8k/train.parquet
+ val_files: ~/data/rlhf/gsm8k/test.parquet
+ train_max_samples: -1
+ val_max_samples: -1
+ prompt_key: prompt
+ reward_fn_key: data_source
+ max_prompt_length: 512
+ max_response_length: 512
+ train_batch_size: 1024
+ val_batch_size: null
+ tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,
+ null}
+ return_raw_input_ids: false
+ return_raw_chat: true
+ return_full_prompt: false
+ shuffle: true
+ seed: null
+ dataloader_num_workers: 8
+ image_patch_size: 14
+ validation_shuffle: false
+ filter_overlong_prompts: false
+ filter_overlong_prompts_workers: 1
+ truncation: error
+ image_key: images
+ video_key: videos
+ trust_remote_code: false
+ custom_cls:
+ path: null
+ name: null
+ return_multi_modal_inputs: true
+ sampler:
+ class_path: null
+ class_name: null
+ datagen:
+ path: null
+ name: null
+ apply_chat_template_kwargs: {}
+critic:
+ optim:
+ _target_: verl.workers.config.TorchtitanOptimizerConfig
+ name: AdamW
+ lr: 1.0e-05
+ lr_warmup_steps_ratio: 0.0
+ total_training_steps: -1
+ weight_decay: 0.01
+ lr_warmup_steps: -1
+ betas:
+ - 0.9
+ - 0.999
+ clip_grad: 1.0
+ eps: 1.0e-08
+ decay_type: linear
+ min_lr_factor: 0.0
+ torchtitan:
+ _target_: verl.workers.config.TorchtitanEngineConfig
+ param_offload: false
+ optimizer_offload: false
+ wrap_policy:
+ min_num_params: 0
+ reshard_after_forward: default
+ forward_prefetch: false
+ use_orig_params: false
+ mixed_precision: false
+ use_torch_compile: true
+ entropy_from_logits_with_chunking: false
+ entropy_checkpointing: false
+ data_parallel_size: 1
+ data_parallel_replicate_size: 1
+ data_parallel_shard_size: 1
+ tensor_parallel_size: 1
+ expert_parallel_size: 1
+ pipeline_parallel_size: 1
+ context_parallel_size: 1
+ attn_type: flex
+ strategy: torchtitan
+ seed: 42
+ full_determinism: false
+ forward_only: false
+ dtype: bfloat16
+ _target_: verl.workers.config.TorchTitanCriticConfig
+ rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
+ strategy: torchtitan
+ enable: null
+ model:
+ path: ~/models/deepseek-llm-7b-chat
+ tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
+ override_config: {}
+ external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
+ trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
+ _target_: verl.trainer.config.BaseModelConfig
+ ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
+ ppo_micro_batch_size: null
+ ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
+ use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
+ ppo_max_token_len_per_gpu: 32768
+ forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}
+ ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}
+ shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}
+ data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}
+ cliprange_value: 0.5
+ loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
+ checkpoint:
+ _target_: verl.trainer.config.CheckpointConfig
+ save_contents:
+ - model
+ - optimizer
+ - extra
+ load_contents: ${.save_contents}
+ async_save: false
+ mbridge_config: {}
+ profiler:
+ _target_: verl.utils.profiler.ProfilerConfig
+ tool: ${oc.select:global_profiler.tool,null}
+ enable: false
+ all_ranks: false
+ ranks: []
+ save_path: ${oc.select:global_profiler.save_path,null}
+ tool_config:
+ nsys:
+ _target_: verl.utils.profiler.config.NsightToolConfig
+ discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
+ npu:
+ _target_: verl.utils.profiler.config.NPUToolConfig
+ contents: []
+ level: level0
+ analysis: true
+ discrete: false
+ torch:
+ _target_: verl.utils.profiler.config.TorchProfilerToolConfig
+ contents: []
+ discrete: false
+ torch_memory:
+ _target_: verl.utils.profiler.config.TorchMemoryToolConfig
+ trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
+ stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
+custom_reward_function:
+ path: null
+ name: null
+reward_model:
+ num_workers: null
+ reward_manager: null
+ enable: null
+ enable_resource_pool: null
+ n_gpus_per_node: null
+ nnodes: null
+ reward_loop_source: null
+ reward_loop_module_path: null
+ reward_loop_class_name: null
+ model:
+ path: null
+ external_lib: null
+ trust_remote_code: null
+ rollout:
+ name: null
+ dtype: null
+ gpu_memory_utilization: null
+ enforce_eager: null
+ cudagraph_capture_sizes: null
+ free_cache_engine: null
+ data_parallel_size: null
+ expert_parallel_size: null
+ tensor_model_parallel_size: null
+ max_num_batched_tokens: null
+ max_model_len: null
+ max_num_seqs: null
+ load_format: null
+ engine_kwargs: null
+ limit_images: null
+ enable_chunked_prefill: null
+ enable_prefix_caching: null
+ disable_log_stats: null
+ skip_tokenizer_init: null
+ prompt_length: null
+ response_length: null
+sandbox_fusion:
+ url: null
+ max_concurrent: null
+ memory_limit_mb: null
+reward:
+ num_workers: 8
+ custom_reward_function:
+ path: null
+ name: compute_score
+ reward_manager:
+ _target_: verl.workers.config.reward_model.RewardManagerConfig
+ source: register
+ name: naive
+ module:
+ _target_: verl.trainer.config.config.ModuleConfig
+ path: null
+ name: custom_reward_manager
+ reward_model:
+ enable: false
+ enable_resource_pool: false
+ n_gpus_per_node: 8
+ nnodes: 0
+ model_path: null
+ rollout:
+ _target_: verl.workers.config.RolloutConfig
+ name: ???
+ dtype: bfloat16
+ gpu_memory_utilization: 0.5
+ enforce_eager: true
+ cudagraph_capture_sizes: null
+ free_cache_engine: true
+ data_parallel_size: 1
+ expert_parallel_size: 1
+ tensor_model_parallel_size: 2
+ max_num_batched_tokens: 8192
+ max_model_len: null
+ max_num_seqs: 1024
+ load_format: auto
+ engine_kwargs: {}
+ limit_images: null
+ enable_chunked_prefill: true
+ enable_prefix_caching: true
+ disable_log_stats: true
+ skip_tokenizer_init: false
+ prompt_length: 2048
+ response_length: 2048
+ sandbox_fusion:
+ url: null
+ max_concurrent: 64
+ memory_limit_mb: 1024
+algorithm:
+ rollout_correction:
+ rollout_is: null
+ rollout_is_threshold: 2.0
+ rollout_rs: null
+ rollout_rs_threshold: null
+ bypass_mode: false
+ loss_type: ppo_clip
+ rollout_is_batch_normalize: false
+ _target_: verl.trainer.config.AlgoConfig
+ gamma: 1.0
+ lam: 1.0
+ adv_estimator: gae
+ norm_adv_by_std_in_grpo: true
+ use_kl_in_reward: false
+ kl_penalty: kl
+ kl_ctrl:
+ _target_: verl.trainer.config.KLControlConfig
+ type: fixed
+ kl_coef: 0.001
+ horizon: 10000
+ target_kl: 0.1
+ use_pf_ppo: false
+ pf_ppo:
+ reweight_method: pow
+ weight_pow: 2.0
+trainer:
+ balance_batch: true
+ total_epochs: 30
+ total_training_steps: null
+ project_name: verl_examples
+ experiment_name: gsm8k
+ logger:
+ - console
+ - wandb
+ log_val_generations: 0
+ rollout_data_dir: null
+ validation_data_dir: null
+ nnodes: 1
+ n_gpus_per_node: 8
+ save_freq: -1
+ esi_redundant_time: 0
+ resume_mode: auto
+ resume_from_path: null
+ val_before_train: true
+ val_only: false
+ test_freq: -1
+ critic_warmup: 0
+ default_hdfs_dir: null
+ del_local_ckpt_after_load: false
+ default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+ max_actor_ckpt_to_keep: null
+ max_critic_ckpt_to_keep: null
+ ray_wait_register_center_timeout: 300
+ device: cuda
+ use_legacy_worker_impl: auto
+global_profiler:
+ _target_: verl.utils.profiler.ProfilerConfig
+ tool: null
+ steps: null
+ profile_continuous_steps: false
+ save_path: outputs/profile
+ global_tool_config:
+ nsys:
+ _target_: verl.utils.profiler.config.NsightToolConfig
+ discrete: false
+ controller_nsight_options:
+ trace: cuda,nvtx,cublas,ucx
+ cuda-memory-usage: 'true'
+ cuda-graph-trace: graph
+ worker_nsight_options:
+ trace: cuda,nvtx,cublas,ucx
+ cuda-memory-usage: 'true'
+ cuda-graph-trace: graph
+ capture-range: cudaProfilerApi
+ capture-range-end: null
+ kill: none
+ torch_memory:
+ trace_alloc_max_entries: 100000
+ stack_depth: 32
+ context: all
+ stacks: all
+ kw_args: {}
+transfer_queue:
+ enable: false
+ray_kwargs:
+ ray_init:
+ num_cpus: null
+ timeline_json_file: null
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 555cea354ae..1cdc21b1ec8 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -21,6 +21,7 @@ actor_rollout_ref:
min_lr_ratio: 0.0
num_cycles: 0.5
lr_scheduler_type: constant
+ zero_indexed_step: true
warmup_style: null
override_optimizer_config: null
fsdp_config:
@@ -126,6 +127,16 @@ actor_rollout_ref:
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
calculate_sum_pi_squared: false
sum_pi_squared_checkpointing: false
+ qat:
+ enable: false
+ mode: w4a16
+ group_size: 16
+ ignore_patterns:
+ - lm_head
+ - embed_tokens
+ - re:.*mlp.gate$
+ activation_observer: static_minmax
+ quantization_config_path: null
ref:
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
strategy: ${actor_rollout_ref.actor.strategy}
@@ -193,6 +204,8 @@ actor_rollout_ref:
_target_: verl.workers.config.RolloutConfig
name: ???
mode: async
+ nnodes: 0
+ n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8}
temperature: 1.0
top_k: -1
top_p: 1
@@ -267,6 +280,8 @@ actor_rollout_ref:
engine_kwargs: {}
trace:
_target_: verl.workers.config.TraceConfig
+ project_name: ${oc.select:trainer.project_name,null}
+ experiment_name: ${oc.select:trainer.experiment_name,null}
backend: null
token2text: false
max_samples_per_step_per_worker: null
@@ -301,6 +316,7 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ qat: ${oc.select:actor_rollout_ref.actor.qat,null}
layered_summon: false
model:
_target_: verl.workers.config.HFModelConfig
@@ -399,6 +415,7 @@ critic:
min_lr_ratio: 0.0
num_cycles: 0.5
lr_scheduler_type: constant
+ zero_indexed_step: true
warmup_style: null
override_optimizer_config: null
model:
@@ -505,6 +522,10 @@ reward_model:
reward_loop_source: null
reward_loop_module_path: null
reward_loop_class_name: null
+ model:
+ path: null
+ external_lib: null
+ trust_remote_code: null
rollout:
name: null
dtype: null
diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
index 80ffe2b68fc..ccaf6582902 100644
--- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml
@@ -26,14 +26,9 @@ actor_rollout_ref:
_target_: verl.workers.config.VeOmniEngineConfig
param_offload: false
optimizer_offload: false
- data_parallel_size: 1
- data_parallel_replicate_size: 1
- data_parallel_shard_size: 1
- tensor_parallel_size: 1
- expert_parallel_size: 1
- pipeline_parallel_size: 1
- context_parallel_size: 1
+ fsdp_size: -1
ulysses_parallel_size: 1
+ expert_parallel_size: 1
mixed_precision: true
seed: 42
full_determinism: false
@@ -167,14 +162,9 @@ actor_rollout_ref:
_target_: verl.workers.config.VeOmniEngineConfig
param_offload: ${oc.select:actor_rollout_ref.actor.veomni.param_offload,False}
optimizer_offload: false
- data_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_size,1}
- data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_replicate_size,1}
- data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_shard_size,1}
- tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.tensor_parallel_size,1}
- expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1}
- pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.pipeline_parallel_size,1}
- context_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.context_parallel_size,1}
+ fsdp_size: ${oc.select:actor_rollout_ref.actor.veomni.fsdp_size,-1}
ulysses_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.ulysses_parallel_size,1}
+ expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1}
mixed_precision: true
seed: ${oc.select:actor_rollout_ref.actor.veomni.seed,42}
full_determinism: false
@@ -196,6 +186,8 @@ actor_rollout_ref:
_target_: verl.workers.config.RolloutConfig
name: ???
mode: async
+ nnodes: 0
+ n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8}
temperature: 1.0
top_k: -1
top_p: 1
@@ -270,6 +262,8 @@ actor_rollout_ref:
engine_kwargs: {}
trace:
_target_: verl.workers.config.TraceConfig
+ project_name: ${oc.select:trainer.project_name,null}
+ experiment_name: ${oc.select:trainer.experiment_name,null}
backend: null
token2text: false
max_samples_per_step_per_worker: null
@@ -304,6 +298,7 @@ actor_rollout_ref:
quantization: null
quantization_config_file: null
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+ qat: ${oc.select:actor_rollout_ref.actor.qat,null}
layered_summon: false
model:
_target_: verl.workers.config.HFModelConfig
@@ -407,14 +402,9 @@ critic:
_target_: verl.workers.config.VeOmniEngineConfig
param_offload: false
optimizer_offload: false
- data_parallel_size: 1
- data_parallel_replicate_size: 1
- data_parallel_shard_size: 1
- tensor_parallel_size: 1
- expert_parallel_size: 1
- pipeline_parallel_size: 1
- context_parallel_size: 1
+ fsdp_size: -1
ulysses_parallel_size: 1
+ expert_parallel_size: 1
mixed_precision: true
seed: 42
full_determinism: false
@@ -500,6 +490,10 @@ reward_model:
reward_loop_source: null
reward_loop_module_path: null
reward_loop_class_name: null
+ model:
+ path: null
+ external_lib: null
+ trust_remote_code: null
rollout:
name: null
dtype: null
diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml
index fc0a16be609..7fbe49c019e 100644
--- a/verl/trainer/config/actor/dp_actor.yaml
+++ b/verl/trainer/config/actor/dp_actor.yaml
@@ -48,3 +48,35 @@ calculate_sum_pi_squared: False
# Enable gradient checkpointing for sum_pi_squared computation (saves memory)
sum_pi_squared_checkpointing: False
+
+# QAT (Quantization-Aware Training) configuration
+# When enabled:
+# - QAT is automatically applied to actor model during training
+# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency
+# - Fast quantization is used when syncing weights to vLLM rollout
+# Supported modes: "w4a16" (NVFP4 weight-only)
+# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use.
+# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md
+qat:
+
+ # Whether to enable QAT
+ enable: false
+
+ # Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended.
+ mode: "w4a16"
+
+ # Quantization group size (NVFP4 requires 16)
+ group_size: 16
+
+ # Patterns to ignore (e.g., lm_head, embed_tokens)
+ ignore_patterns:
+
+ - "lm_head"
+ - "embed_tokens"
+ - "re:.*mlp.gate$"
+
+ # Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax"
+ activation_observer: "static_minmax"
+
+ # Path to vLLM quantization config JSON file
+ quantization_config_path: null
diff --git a/verl/trainer/config/actor/torchtitan_actor.yaml b/verl/trainer/config/actor/torchtitan_actor.yaml
new file mode 100644
index 00000000000..3bf25b9a5a7
--- /dev/null
+++ b/verl/trainer/config/actor/torchtitan_actor.yaml
@@ -0,0 +1,16 @@
+# torchtitan actor config, inheriting from trainer/config/actor/actor.yaml
+defaults:
+ # torchtitan optimizer config
+ - ../optim@optim: torchtitan
+
+ # torchtitan engine config
+ - ../engine@torchtitan: torchtitan
+
+ - actor
+
+ # load the reference default config, then apply the fields in the current yaml
+ - _self_
+
+_target_: verl.workers.config.TorchTitanActorConfig
+
+strategy: torchtitan
diff --git a/verl/trainer/config/critic/torchtitan_critic.yaml b/verl/trainer/config/critic/torchtitan_critic.yaml
new file mode 100644
index 00000000000..4fafbd9d227
--- /dev/null
+++ b/verl/trainer/config/critic/torchtitan_critic.yaml
@@ -0,0 +1,28 @@
+# defaults specify the default config from each component
+defaults:
+
+ # torchtitan optimizer config
+ - ../optim@optim: torchtitan
+
+ # torchtitan engine config
+ - ../engine@torchtitan: torchtitan
+
+ # critic config, inheriting from trainer/config/critic/critic.yaml
+ - critic
+
+ # load the reference default config, then apply the fields in the current yaml
+ - _self_
+
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+_target_: verl.workers.config.TorchTitanCriticConfig
+
+strategy: torchtitan
+
+# model config for the critic
+model:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.trainer.config.BaseModelConfig
+
+# seed for data loader
+data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}
diff --git a/verl/trainer/config/engine/torchtitan.yaml b/verl/trainer/config/engine/torchtitan.yaml
new file mode 100644
index 00000000000..0f75004c0a4
--- /dev/null
+++ b/verl/trainer/config/engine/torchtitan.yaml
@@ -0,0 +1,74 @@
+# Target class for this configuration
+_target_: verl.workers.config.TorchtitanEngineConfig
+
+# Whether to offload model parameters to CPU
+param_offload: False
+
+# Whether to offload optimizer state to CPU
+optimizer_offload: False
+
+# policy for wrapping the model
+wrap_policy:
+ # Minimum number of parameters to trigger wrapping a layer with FSDP
+ min_num_params: 0
+
+# The policy for applying `reshard_after_forward` within an FSDP setup
+# Options: "default", "always", "never"
+reshard_after_forward: default
+
+# Prefetch the next forward-pass all-gather before the current forward computation.
+forward_prefetch: false
+
+# Whether to use original parameters
+use_orig_params: false
+
+# Mixed precision configuration for FSDP
+mixed_precision: false
+
+# Whether to use torch compile
+use_torch_compile: true
+
+# Whether to use entropy_from_logits_with_chunking
+entropy_from_logits_with_chunking: false
+
+# Whether to use entropy checkpointing
+entropy_checkpointing: false
+
+# Data parallel size (FSDP group size)
+data_parallel_size: 1
+
+# Data parallel replicate size
+data_parallel_replicate_size: 1
+
+# Data parallel shard size
+data_parallel_shard_size: 1
+
+# Tensor parallel size
+tensor_parallel_size: 1
+
+# Expert parallel size
+expert_parallel_size: 1
+
+# Pipeline parallel size
+pipeline_parallel_size: 1
+
+# Context parallel size
+context_parallel_size: 1
+
+# Attention type for torchtitan's model (e.g., "sdpa", "flex", "varlen")
+attn_type: flex
+
+# Strategy
+strategy: torchtitan
+
+# Random seed for reproducibility
+seed: 42
+
+# Whether to enable full determinism for distributed training, only for debugging
+full_determinism: false
+
+# Whether to use forward only
+forward_only: false
+
+# Mixed precision training param dtype
+dtype: bfloat16
diff --git a/verl/trainer/config/engine/veomni.yaml b/verl/trainer/config/engine/veomni.yaml
index 3bbbcf55e60..1869dd745f9 100644
--- a/verl/trainer/config/engine/veomni.yaml
+++ b/verl/trainer/config/engine/veomni.yaml
@@ -7,22 +7,13 @@ param_offload: False
# Whether to offload optimizer state to CPU
optimizer_offload: False
-data_parallel_size: 1
+# FSDP group size. -1 means use all available GPUs.
+fsdp_size: -1
-data_parallel_replicate_size: 1
-
-data_parallel_shard_size: 1
-
-tensor_parallel_size: 1
+ulysses_parallel_size: 1
expert_parallel_size: 1
-pipeline_parallel_size: 1
-
-context_parallel_size: 1
-
-ulysses_parallel_size: 1
-
mixed_precision: true
# Random seed for reproducibility.
diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml
deleted file mode 100644
index 478733339ce..00000000000
--- a/verl/trainer/config/generation.yaml
+++ /dev/null
@@ -1,62 +0,0 @@
-trainer:
- nnodes: 1
- n_gpus_per_node: 8
- device: cuda
-
-data:
- path: ~/data/rlhf/math/test.parquet
- prompt_key: prompt
- n_samples: 5
- output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
- batch_size: 128
-
-model:
- path: ~/models/Qwen2-7B-Instruct
- external_lib: null
-rollout:
- _target_: verl.workers.config.RolloutConfig
- name: vllm
- # NOTE: 'sync' mode was removed in PR #4411. Only 'async' mode is supported.
- # WARNING: The main_generation.py workflow is currently broken for vLLM async rollout
- # as it requires synchronous generate_sequences() which vLLMAsyncRollout doesn't support.
- # See issue #4682 for discussion and workarounds.
- mode: async
- temperature: 1.0
- top_k: 50 # 0 for hf rollout, -1 for vllm rollout
- top_p: 0.7
- prompt_length: 1536
- response_length: 512
- # for vllm rollout
- dtype: bfloat16 # should align with FSDP
- gpu_memory_utilization: 0.5
- ignore_eos: False
- enforce_eager: True
- free_cache_engine: True
- load_format: auto
- tensor_model_parallel_size: 1
- data_parallel_size: 1
- max_num_batched_tokens: 8192
- max_model_len: null
- max_num_seqs: 1024
- log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
- log_prob_micro_batch_size_per_gpu: 8
- # for hf rollout
- do_sample: True
- disable_log_stats: True
- enable_chunked_prefill: True
- n: 1
- # support logging rollout prob for debugging purpose
- calculate_log_probs: False
-actor:
- strategy: fsdp # This is for backward-compatibility
- ulysses_sequence_parallel_size: 1 # sp size
- entropy_from_logits_with_chunking: False # calculate entropy with chunking to reduce memory peak
- entropy_checkpointing: False # recompute entropy
- fsdp_config:
- fsdp_size: -1
- forward_prefetch: False # FSDP1 forward_prefetch configuration
-
-ray_kwargs:
- ray_init:
- num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
- timeline_json_file: null
diff --git a/verl/trainer/config/legacy_reward_impl.yaml b/verl/trainer/config/legacy_reward_impl.yaml
index 476a6e2810f..e1266ba9a89 100644
--- a/verl/trainer/config/legacy_reward_impl.yaml
+++ b/verl/trainer/config/legacy_reward_impl.yaml
@@ -12,6 +12,10 @@ reward_model:
reward_loop_source: null
reward_loop_module_path: null
reward_loop_class_name: null
+ model:
+ path: null
+ external_lib: null
+ trust_remote_code: null
rollout:
name: null
dtype: null
diff --git a/verl/trainer/config/model_engine/torchtitan.yaml b/verl/trainer/config/model_engine/torchtitan.yaml
new file mode 100644
index 00000000000..32a4b3d0611
--- /dev/null
+++ b/verl/trainer/config/model_engine/torchtitan.yaml
@@ -0,0 +1,2 @@
+# @package _global_
+model_engine: torchtitan
diff --git a/verl/trainer/config/optim/fsdp.yaml b/verl/trainer/config/optim/fsdp.yaml
index a7dd99b1ee2..ce6ced773b6 100644
--- a/verl/trainer/config/optim/fsdp.yaml
+++ b/verl/trainer/config/optim/fsdp.yaml
@@ -38,6 +38,9 @@ num_cycles: 0.5
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
+# Whether the LR schedule uses 0-indexed steps
+zero_indexed_step: true
+
# deprecated
warmup_style: null
diff --git a/verl/trainer/config/optim/torchtitan.yaml b/verl/trainer/config/optim/torchtitan.yaml
new file mode 100644
index 00000000000..baea31ee527
--- /dev/null
+++ b/verl/trainer/config/optim/torchtitan.yaml
@@ -0,0 +1,35 @@
+# Target class for this configuration
+_target_: verl.workers.config.TorchtitanOptimizerConfig
+
+# Optimizer name
+name: AdamW
+
+# Learning rate
+lr: 1e-3
+
+# LR warmup steps ratio
+lr_warmup_steps_ratio: 0.0
+
+# Total training steps
+total_training_steps: -1
+
+# Weight decay
+weight_decay: 0.01
+
+# LR warmup steps
+lr_warmup_steps: -1
+
+# Betas for Adam optimizer
+betas: [0.9, 0.999]
+
+# Clip gradient
+clip_grad: 1.0
+
+# Epsilon for Adam optimizer
+eps: 1e-8
+
+# Decay type: "linear", "sqrt", or "cosine"
+decay_type: linear
+
+# Minimum LR factor for cosine schedule
+min_lr_factor: 0.0
diff --git a/verl/trainer/config/ref/torchtitan_ref.yaml b/verl/trainer/config/ref/torchtitan_ref.yaml
new file mode 100644
index 00000000000..80cc01b1098
--- /dev/null
+++ b/verl/trainer/config/ref/torchtitan_ref.yaml
@@ -0,0 +1,28 @@
+# torchtitan ref config, inheriting from trainer/config/ref/ref.yaml
+defaults:
+ # torchtitan optimizer config
+ - ../optim@optim: torchtitan
+
+ # torchtitan engine config
+ - ../engine@torchtitan: torchtitan
+
+ - ref
+
+ # load the reference default config, then apply the fields in the current yaml
+ - _self_
+
+_target_: verl.workers.config.TorchTitanActorConfig
+
+strategy: torchtitan
+
+torchtitan:
+ seed: ${oc.select:actor_rollout_ref.actor.torchtitan.seed,42}
+ data_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_size,1}
+ data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_replicate_size,1}
+ data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_shard_size,1}
+ tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.tensor_parallel_size,1}
+ expert_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.expert_parallel_size,1}
+ pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.pipeline_parallel_size,1}
+ context_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.context_parallel_size,1}
+ attn_type: ${oc.select:actor_rollout_ref.actor.torchtitan.attn_type,flex}
+ forward_only: True
diff --git a/verl/trainer/config/ref/veomni_ref.yaml b/verl/trainer/config/ref/veomni_ref.yaml
index f52421fd39f..738962356fe 100644
--- a/verl/trainer/config/ref/veomni_ref.yaml
+++ b/verl/trainer/config/ref/veomni_ref.yaml
@@ -14,14 +14,9 @@ strategy: veomni
veomni:
seed: ${oc.select:actor_rollout_ref.actor.veomni.seed,42}
- data_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_size,1}
- data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_replicate_size,1}
- data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_shard_size,1}
- tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.tensor_parallel_size,1}
- expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1}
- pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.pipeline_parallel_size,1}
- context_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.context_parallel_size,1}
+ fsdp_size: ${oc.select:actor_rollout_ref.actor.veomni.fsdp_size,-1}
ulysses_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.ulysses_parallel_size,1}
+ expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1}
param_offload: ${oc.select:actor_rollout_ref.actor.veomni.param_offload,False}
attn_implementation: ${oc.select:actor_rollout_ref.actor.veomni.attn_implementation,flash_attention_2}
moe_implementation: ${oc.select:actor_rollout_ref.actor.veomni.moe_implementation,fused}
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index 575c27551fa..894538d1d87 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -7,6 +7,12 @@ name: ???
# sync: LLM, async: AsyncLLM
mode: async
+# Number of nodes for standalone rollout server, must be > 0 in one-step-off/fully async training.
+nnodes: 0
+
+# Number of GPUs per node for rollout server.
+n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8}
+
# Sampling temperature for rollout.
temperature: 1.0
@@ -273,6 +279,12 @@ trace:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.TraceConfig
+ # Project name for experiment tracking (e.g., wandb)
+ project_name: ${oc.select:trainer.project_name,null}
+
+ # Experiment name for run identification in tracking tools
+ experiment_name: ${oc.select:trainer.experiment_name,null}
+
# trace backend, support mlflow, weave
backend: null
@@ -386,3 +398,6 @@ quantization_config_file: null
# MTP configuration, reuse model configuration
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
+
+# QAT configuration (inherited from actor.qat)
+qat: ${oc.select:actor_rollout_ref.actor.qat,null}
diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml
deleted file mode 100644
index b2308e39e44..00000000000
--- a/verl/trainer/config/sft_trainer.yaml
+++ /dev/null
@@ -1,91 +0,0 @@
-defaults:
- - optim: fsdp
- - _self_
-
-data:
- train_batch_size: 256
- micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
- micro_batch_size_per_gpu: 4 # this is also val batch size
- train_files: ~/data/gsm8k/train.parquet
- val_files: ~/data/gsm8k/test.parquet
- train_max_samples: -1 # set to -1 to use full dataset
- val_max_samples: -1 # set to -1 to use full dataset
- # Single-turn settings
- prompt_key: question
- response_key: answer
- prompt_dict_keys: null
- response_dict_keys: null
- # Multi-turn settings
- multiturn:
- enable: false # Set to true to use multi-turn dataset
- messages_key: messages # Key for messages list in multi-turn mode
- tools_key: tools # Key for tools list in multi-turn mode
- enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode
- max_length: 1024
- truncation: error
- balance_dp_token: False
- chat_template: null
- custom_cls:
- path: null
- name: null
- use_shm: False
- apply_chat_template_kwargs: {}
-model:
- partial_pretrain: ~/models/gemma-1.1-7b-it
- use_shm: False
- fsdp_config:
- model_dtype: fp32
- wrap_policy:
- min_num_params: 0
- cpu_offload: False
- offload_params: False
- external_lib: null
- enable_gradient_checkpointing: True
- trust_remote_code: False
- lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
- lora_alpha: 16 # LoRA scaling factor
- target_modules: all-linear # Target modules for LoRA adaptation
- use_liger: False
- strategy: fsdp2
-optim:
- lr: 1e-5
- betas: [0.9, 0.95]
- weight_decay: 0.01
- lr_warmup_steps_ratio: 0.1
- clip_grad: 1.0
- lr_scheduler: cosine
-ulysses_sequence_parallel_size: 1
-use_remove_padding: False
-trainer:
- default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
- default_hdfs_dir: null
- project_name: gsm8k-sft
- experiment_name: test
- total_epochs: 4
- total_training_steps: null
- logger: [ 'console', 'wandb' ]
- seed: 1
- save_freq: -1
- test_freq: -1
- nnodes: 1
- n_gpus_per_node: 8
- max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all
-
- # Resume mode: "auto", "disable", or "resume_path"
- # "auto": resume from last checkpoint if available
- # "disable": start from scratch
- # "resume_path": resume from a user-defined path
- resume_mode: auto
-
- # Path to resume training from (used when resume_mode is "resume_path" or "auto")
- resume_from_path: null
-
- # Checkpoint configuration
- checkpoint:
- # What to include in saved checkpoints
- # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
- save_contents: ["model", "optimizer", "extra"]
-
- # For more flexibility, you can specify the contents to load from the checkpoint.
- load_contents: ${trainer.checkpoint.save_contents}
- device: cuda
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
deleted file mode 100644
index e8cc00e16c0..00000000000
--- a/verl/trainer/fsdp_sft_trainer.py
+++ /dev/null
@@ -1,873 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-A lightweight one-file FSDP SFT Trainer
-TODO(zhangchi.usc1992)
-- Add calculation of mfu
-- Add validation
-"""
-
-import os
-
-os.environ["NCCL_DEBUG"] = "WARN"
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-import logging
-import re
-import time
-from contextlib import nullcontext
-
-import hydra
-import torch
-import torch.distributed
-from omegaconf import DictConfig, OmegaConf
-from peft import LoraConfig, TaskType, get_peft_model
-from tensordict import TensorDict
-from torch import nn
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.utils.data import Dataset, DistributedSampler
-from torchdata.stateful_dataloader import StatefulDataLoader
-from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
-
-import verl.utils.hdfs_io as hdfs_io
-from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input
-from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename
-from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
-from verl.utils.dataset import SFTDataset
-from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
-from verl.utils.device import (
- auto_set_device,
- get_device_id,
- get_device_name,
- is_cuda_available,
- is_npu_available,
-)
-from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group
-from verl.utils.fs import copy_to_local
-from verl.utils.fsdp_utils import (
- CPUOffloadPolicy,
- MixedPrecisionPolicy,
- apply_fsdp2,
- fsdp2_clip_grad_norm_,
- fsdp2_load_full_state_dict,
- get_fsdp_wrap_policy,
- get_init_weight_context_manager,
- init_fn,
-)
-from verl.utils.logger import log_with_rank
-from verl.utils.profiler import log_gpu_memory_usage
-from verl.utils.py_functional import convert_to_regular_types
-from verl.utils.torch_dtypes import PrecisionType
-from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
-from verl.utils.tracking import Tracking
-from verl.utils.ulysses import (
- gather_outputs_and_unpad,
- get_ulysses_sequence_parallel_world_size,
- ulysses_pad_and_slice_inputs,
-)
-from verl.workers.config.optimizer import build_optimizer
-from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN"))
-
-
-def extract_step(path):
- match = re.search(r"global_step_(\d+)", path)
- if match:
- return int(match.group(1))
- return None
-
-
-class FSDPSFTTrainer:
- def __init__(
- self,
- config,
- device_mesh: DeviceMesh,
- ulysses_device_mesh: DeviceMesh,
- tokenizer,
- train_dataset: Dataset,
- val_dataset: Dataset,
- ):
- self.config = config
- self.device_mesh = device_mesh
- self.ulysses_device_mesh = ulysses_device_mesh
- self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
- self.tokenizer = tokenizer
- if self.config.data.chat_template is not None:
- raise ValueError("Apply Chat template from config is not supported yet.")
-
- # normalize dp size
- self._normalize_config_bsz()
-
- # Set sequence parallel size
- self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1)
- self.use_remove_padding = getattr(self.config, "use_remove_padding", False)
- if self.device_mesh.get_rank() == 0:
- print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
- print(f"Using remove padding: {self.use_remove_padding}")
-
- self._build_dataloader(train_dataset, val_dataset)
-
- self.lora = self.config.model.get("lora_adapter_path") is not None or self.config.model.lora_rank > 0
-
- # Initialize resume-related variables
- self.resume_global_step = 0
-
- # build model
- self._build_model_optimizer()
-
- # Initialize checkpoint manager
- self._init_checkpoint_manager()
-
- self.load_checkpoint()
-
- if self.device_mesh.get_rank() == 0:
- print(self.config)
-
- self.device_name = self.config.trainer.device
-
- def _normalize_config_bsz(self):
- dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
- if self.device_mesh.get_rank() == 0:
- print(f"Normalize batch size by dp {dp_size}")
-
- assert self.config.data.train_batch_size % dp_size == 0, (
- f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"
- )
-
- self.config.data.train_batch_size //= dp_size
-
- assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0
-
- def _build_dataloader(self, train_dataset, val_dataset):
- # build dataset
- config = self.config
- self.train_dataset, self.val_dataset = train_dataset, val_dataset
-
- # build dataloader
- # Use data parallel rank and size instead of global rank and world size
-
- # If doing SP, we need to use the local rank and size
- if self.config.ulysses_sequence_parallel_size > 1:
- rank = self.ulysses_device_mesh.get_local_rank("dp")
- world_size = self.ulysses_device_mesh.size(0)
- if self.ulysses_device_mesh.get_rank() == 0:
- print(f"Using SP rank {rank} and size {world_size} for data distribution")
- print("Each SP rank gets different data, but the same data WITHIN the same rank")
- else:
- rank = self.device_mesh.get_rank()
- world_size = self.device_mesh.size()
- if self.device_mesh.get_rank() == 0:
- print(f"Using FSDP rank {rank} and size {world_size} for data distribution")
-
- # Set pin_memory_device when pin_memory is enabled.
- device_name = get_device_name()
-
- self.train_sampler = DistributedSampler(
- self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True
- )
- self.train_dataloader = StatefulDataLoader(
- dataset=self.train_dataset,
- batch_size=config.data.train_batch_size,
- sampler=self.train_sampler,
- num_workers=8,
- pin_memory=True,
- drop_last=True,
- pin_memory_device=device_name,
- )
-
- self.val_sampler = DistributedSampler(
- self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True
- )
- self.val_dataloader = StatefulDataLoader(
- dataset=self.val_dataset,
- batch_size=config.data.micro_batch_size_per_gpu,
- sampler=self.val_sampler,
- num_workers=8,
- pin_memory=True,
- drop_last=True,
- pin_memory_device=device_name,
- )
-
- def _build_model_optimizer(self):
- # TODO (zhangchi.usc1992):
- # 1. support pretrain from random weights
- # 2. support init directly from sharded weights
- local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
-
- if self.config.model.get("external_lib", None) is not None:
- # This is used to import external_lib into the huggingface systems
- import importlib
-
- importlib.import_module(self.config.model.external_lib)
-
- log_gpu_memory_usage("Before model allocation", logger=logger)
-
- trust_remote_code = self.config.model.trust_remote_code
- torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
- torch_dtype = PrecisionType.to_dtype(torch_dtype)
- # load config first
- config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
- self.model_config = config
- if hasattr(self.model_config, "max_position_embeddings"):
- self.model_config.max_position_embeddings = max(
- self.model_config.max_position_embeddings, self.config.data.max_length
- )
- if self.config.ulysses_sequence_parallel_size > 1:
- assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"
-
- # This may be very large
- init_context = get_init_weight_context_manager(
- use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh
- )
-
- with init_context():
- self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
- local_model_path,
- config=config,
- torch_dtype=torch_dtype,
- attn_implementation="flash_attention_2",
- trust_remote_code=trust_remote_code,
- )
-
- if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
- from verl.models.transformers.monkey_patch import apply_monkey_patch
-
- apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)
-
- # Apply Liger kernel if use_liger is enabled
- if self.config.model.get("use_liger", False):
- from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
-
- _apply_liger_kernel_to_instance(model=self.model)
-
- if self.lora:
- self.model.enable_input_require_grads()
-
- lora_adapter_path = self.config.model.get("lora_adapter_path")
- if lora_adapter_path is not None:
- from peft import PeftModel
-
- print(f"Loading pre-trained LoRA adapter for sft from: {lora_adapter_path}")
-
- local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.use_shm)
-
- self.model = PeftModel.from_pretrained(self.model, local_adapter_path, is_trainable=True)
- peft_config = self.model.peft_config["default"]
- # Ensure task_type is TaskType enum, not string
- if isinstance(peft_config.task_type, str):
- peft_config.task_type = TaskType.CAUSAL_LM
- else:
- # Convert config to regular Python types before creating PEFT model
- lora_config = {
- "task_type": TaskType.CAUSAL_LM,
- "r": self.config.model.lora_rank,
- "lora_alpha": self.config.model.lora_alpha,
- "target_modules": convert_to_regular_types(self.config.model.target_modules),
- "bias": "none",
- }
- self.model = get_peft_model(self.model, LoraConfig(**lora_config))
- self.model = self.model.to(torch_dtype)
-
- if self.config.model.enable_gradient_checkpointing:
- self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-
- log_gpu_memory_usage("After model allocation", logger=logger)
-
- mixed_precision = MixedPrecision(
- param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32
- )
-
- auto_wrap_policy = get_fsdp_wrap_policy(
- self.model,
- config=self.config.model.fsdp_config.wrap_policy,
- is_lora=self.lora,
- )
-
- if self.device_mesh.get_rank() == 0:
- print(auto_wrap_policy)
-
- if not self.config.model.fsdp_config.cpu_offload:
- cpu_offload = None
- else:
- cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)
-
- fsdp_strategy = self.config.model.strategy
- if fsdp_strategy == "fsdp":
- self.fsdp_model = FSDP(
- self.model,
- cpu_offload=cpu_offload,
- param_init_fn=init_fn,
- use_orig_params=False,
- auto_wrap_policy=auto_wrap_policy,
- device_id=get_device_id(),
- sharding_strategy=ShardingStrategy.FULL_SHARD,
- mixed_precision=mixed_precision,
- sync_module_states=True,
- device_mesh=self.device_mesh,
- forward_prefetch=False,
- )
- elif fsdp_strategy == "fsdp2":
- assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
- mp_policy = MixedPrecisionPolicy(
- param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
- )
-
- fsdp_kwargs = {
- "mesh": self.device_mesh,
- "mp_policy": mp_policy,
- "offload_policy": cpu_offload,
- "reshard_after_forward": True,
- }
- full_state = self.model.state_dict()
- apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config)
- fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload)
- self.fsdp_model = self.model
- else:
- raise NotImplementedError(f"not implement {fsdp_strategy}")
-
- log_gpu_memory_usage("After FSDP wrapping", logger=logger)
-
- self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim)
-
- log_gpu_memory_usage("After initialize optimizer", logger=logger)
-
- self.steps_per_epoch = len(self.train_dataloader)
- self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs
-
- if self.device_mesh.get_rank() == 0:
- print(
- f"Number of steps/epoch {self.steps_per_epoch}, number of epochs "
- f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}"
- )
-
- num_warmup_steps = int(self.total_steps * self.config.optim.lr_warmup_steps_ratio)
-
- if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine":
- self.lr_scheduler = get_cosine_schedule_with_warmup(
- optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps
- )
- elif self.config.optim.lr_scheduler == "wsd":
- self.lr_scheduler = get_wsd_schedule_with_warmup(
- optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps
- )
- else:
- raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}")
-
- def _compute_loss_and_backward(self, batch, do_backward=True, n_micro_batches=1):
- """Compute loss with optional sequence parallelism and remove padding features"""
- use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1
-
- # Move inputs to GPU and prepare loss mask
- input_ids = batch["input_ids"].to(self.device_name)
- attention_mask = batch["attention_mask"].to(self.device_name)
- position_ids = batch["position_ids"].to(self.device_name)
- loss_mask = batch.pop("loss_mask")[:, 1:].reshape(-1).to(self.device_name)
- loss_fct = nn.CrossEntropyLoss(reduction="none")
-
- # Context manager for sequence parallel if needed
- context = self.sharding_manager if use_sp else nullcontext()
- with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
- if not use_sp:
- # Standard forward pass without sequence parallel
- labels = input_ids[:, 1:].contiguous()
- output = self.fsdp_model(
- input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
- )
- logits = output.logits
-
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels.contiguous()
- # Flatten the tokens
- shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
- loss = loss * loss_mask.to(loss.device)
- else:
- # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks
- # i.e., each GPU has <1 sequence, and each SP group has 1 sequence
- # 1. All SP ranks will receive the *SAME* batch
- # 2. Different SP groups will receive *DIFFERENT* batches
- # This is implemented by the DistributedSampler
-
- batch_size, seqlen = input_ids.shape
- # Remove padding
- input_ids_rmpad, indices, *_ = unpad_input(
- input_ids.unsqueeze(-1), attention_mask
- ) # input_ids_rmpad (total_nnz, ...)
- input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
-
- # Unpad position_ids to align rotary
- position_ids_rmpad = index_first_axis(
- rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
- ).transpose(0, 1)
-
- # Pad and slice inputs for sequence parallelism
- input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
- input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
- )
- # For computing loss
- input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
- input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
- input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()
- )
- input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
-
- # Forward pass
- output = self.fsdp_model(
- input_ids=input_ids_rmpad_sliced,
- attention_mask=None, # Not needed with flash attention varlen
- position_ids=position_ids_rmpad_padded,
- use_cache=False,
- )
-
- # Compute loss locally then aggregate
- logits_rmpad = output.logits.squeeze(0)
- input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
- loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
- # Gather and unpad for sequence parallelism
- loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
-
- # This is the loss collected from all ulysses ranks
- full_loss = pad_input(
- hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
- )
- full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss
- full_loss = full_loss.reshape(-1)
- loss_mask = loss_mask.to(full_loss.device)
- loss = full_loss * loss_mask
-
- valid_token_this_rank = torch.sum(loss_mask)
-
- if self.config.data.balance_dp_token:
- torch.distributed.all_reduce(valid_token_this_rank)
- dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size()
- else:
- dp_size = 1
-
- loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size
-
- loss = loss / n_micro_batches # normalize loss
-
- if do_backward:
- loss.backward()
- return loss
-
- def training_step(self, batch: TensorDict):
- start_time = time.time()
-
- self.fsdp_model.train()
-
- log_gpu_memory_usage("Before optimizer zero_grad", logger=logger)
-
- self.optimizer.zero_grad()
-
- log_gpu_memory_usage("After optimizer zero_grad", logger=logger)
-
- micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
- n_micro_batches = len(micro_batches)
- step_loss = 0
- for micro_batch in micro_batches:
- loss = self._compute_loss_and_backward(batch=micro_batch, n_micro_batches=n_micro_batches)
- step_loss += loss.item()
-
- if self.config.model.strategy == "fsdp":
- grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
- elif self.config.model.strategy == "fsdp2":
- grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad)
- else:
- raise NotImplementedError(f"not implement {self.config.model.strategy}")
-
- log_gpu_memory_usage("Before optimizer step", logger=logger)
-
- # if grad_norm is not finite, skip the update
- if not torch.isfinite(grad_norm):
- print(f"WARN: grad_norm is not finite: {grad_norm}")
- self.optimizer.zero_grad()
- else:
- self.optimizer.step()
-
- log_gpu_memory_usage("After optimizer step", logger=logger)
-
- self.lr_scheduler.step()
-
- # reduce loss across dp ranks
- lr = self.lr_scheduler.get_last_lr()[0]
-
- log_gpu_memory_usage("After offload weights", logger=logger)
-
- step_loss = torch.tensor(step_loss).to(self.device_name)
-
- # compute time spent per step
- end_time = time.time()
- spend_time_per_step = end_time - start_time
-
- if is_cuda_available:
- torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
- elif is_npu_available:
- torch.distributed.all_reduce(step_loss)
- step_loss /= self.device_mesh.size(0)
- return {
- "train/loss": step_loss.detach().item(),
- "train/lr(1e-3)": lr * 1e3,
- "train/time(s)": spend_time_per_step,
- }
-
- def validation_step(self, batch: TensorDict):
- self.fsdp_model.eval()
- with torch.no_grad():
- loss = self._compute_loss_and_backward(batch, do_backward=False)
- if is_cuda_available:
- torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
- elif is_npu_available:
- torch.distributed.all_reduce(loss)
- loss /= self.device_mesh.size(0)
- return loss
-
- def save_checkpoint(self, step):
- """Save checkpoint using FSDPCheckpointManager with improved tracking"""
- from verl.utils.fs import local_mkdir_safe
-
- # Determine checkpoint path
- local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}")
-
- if self.device_mesh.get_rank() == 0:
- print(f"Saving checkpoint to: {local_global_step_folder}")
-
- # Get max checkpoints to keep
- max_ckpt_to_keep = getattr(self.config.trainer, "max_ckpt_to_keep", None)
-
- # Use checkpoint manager to save
- self.checkpoint_manager.save_checkpoint(
- local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep
- )
-
- # Save dataloader state
- if self.device_mesh.get_rank() == 0:
- local_mkdir_safe(local_global_step_folder)
- dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
-
- # Use StatefulDataLoader's built-in state dict functionality
- dataloader_state_dict = self.train_dataloader.state_dict()
- torch.save(dataloader_state_dict, dataloader_local_path)
- print(f"Saved dataloader state to: {dataloader_local_path}")
-
- # Update latest checkpoint tracker (atomic write)
- tracker_file = get_checkpoint_tracker_filename(self.config.trainer.default_local_dir)
- temp_tracker_file = tracker_file + ".tmp"
- with open(temp_tracker_file, "w") as f:
- f.write(str(step))
- os.rename(temp_tracker_file, tracker_file)
- print(f"Updated checkpoint tracker: {tracker_file}")
-
- # Copy to HDFS if configured
- if self.device_mesh.get_rank() == 0 and getattr(self.config.trainer, "default_hdfs_dir", None):
- hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
- hdfs_io.copy(src=local_global_step_folder, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
-
- torch.distributed.barrier()
-
- def _init_checkpoint_manager(self):
- """Initialize checkpoint manager with proper configuration"""
- # Get checkpoint configuration from config, with defaults
- checkpoint_config = getattr(self.config.trainer, "checkpoint", {})
-
- # Set default values if not specified
- save_contents = checkpoint_config.get("save_contents", ["model", "optimizer", "extra"])
- load_contents = checkpoint_config.get("load_contents", save_contents)
-
- # Create checkpoint config dict
- checkpoint_config_dict = {
- "load_contents": load_contents,
- "save_contents": save_contents,
- }
-
- # Convert to DictConfig for compatibility
- checkpoint_config_dict = DictConfig(checkpoint_config_dict)
-
- # Initialize checkpoint manager
- self.checkpoint_manager = FSDPCheckpointManager(
- model=self.fsdp_model,
- optimizer=self.optimizer,
- lr_scheduler=self.lr_scheduler,
- processing_class=self.tokenizer,
- checkpoint_config=checkpoint_config_dict,
- trust_remote_code=self.config.model.trust_remote_code,
- )
-
- def load_checkpoint(self):
- # Determine resume path based on configuration
- checkpoint_path = self._determine_resume_path()
-
- if checkpoint_path is None:
- return 0
-
- # extract resume step from checkpoint path
- resume_step = extract_step(checkpoint_path)
- if resume_step is None:
- log_with_rank(
- f"Warning: Could not extract step number from {checkpoint_path}, starting from step 0",
- logger=logger,
- rank=self.device_mesh.get_rank(),
- level=logging.WARNING,
- log_only_rank_0=True,
- )
- return 0
- self.resume_global_step = resume_step
-
- # Use checkpoint manager to load model state
- self.checkpoint_manager.load_checkpoint(checkpoint_path)
- log_with_rank(
- f"Successfully loaded model checkpoint from {checkpoint_path} (step {resume_step})",
- logger=logger,
- rank=self.device_mesh.get_rank(),
- log_only_rank_0=True,
- )
-
- # Always load dataloader state for StatefulDataLoader
- self._load_dataloader_state(checkpoint_path)
-
- return resume_step
-
- def _load_dataloader_state(self, checkpoint_path: str):
- """Load dataloader state from checkpoint"""
- dataloader_path = os.path.join(checkpoint_path, "data.pt")
-
- if os.path.exists(dataloader_path):
- # Use StatefulDataLoader's built-in state dict functionality
- dataloader_state_dict = torch.load(dataloader_path, map_location="cpu", weights_only=False)
- self.train_dataloader.load_state_dict(dataloader_state_dict)
-
- log_with_rank(
- f"Successfully loaded dataloader state from {dataloader_path}",
- logger=logger,
- rank=self.device_mesh.get_rank(),
- log_only_rank_0=True,
- )
-
- else:
- log_with_rank(
- f"Warning: No dataloader state found at {dataloader_path}, will start from scratch",
- logger=logger,
- rank=self.device_mesh.get_rank(),
- level=logging.WARNING,
- log_only_rank_0=True,
- )
-
- def _determine_resume_path(self):
- """Determine the path to resume from based on resume_mode configuration"""
- resume_mode = getattr(self.config.trainer, "resume_mode", "auto")
- resume_from_path = getattr(self.config.trainer, "resume_from_path", None)
-
- if resume_mode == "disable":
- return None
- elif resume_mode == "auto":
- if resume_from_path is not None:
- assert os.path.exists(resume_from_path), (
- "resume_from_path must be null or an existing path when resume_mode is 'auto'"
- )
- assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps"
- return resume_from_path
- # Try to find the latest checkpoint in the default directory
- return self._find_latest_checkpoint()
- elif resume_mode == "resume_path":
- assert os.path.exists(resume_from_path), (
- "resume_from_path must be an existing path when resume_mode is 'resume_path'"
- )
- assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps"
- return resume_from_path
- else:
- raise ValueError(f"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'")
-
- def _find_latest_checkpoint(self):
- """Find the latest checkpoint in the default local directory"""
- checkpoint_dir = self.config.trainer.default_local_dir
-
- if not os.path.exists(checkpoint_dir):
- return None
-
- latest_checkpoint = find_latest_ckpt_path(checkpoint_dir)
-
- if latest_checkpoint and self.device_mesh.get_rank() == 0:
- step_num = extract_step(latest_checkpoint)
- print(f"Found latest checkpoint: {latest_checkpoint} (step {step_num})")
-
- return latest_checkpoint
-
- def fit(self):
- rank = self.device_mesh.get_rank()
-
- # TODO: add a unified tracking
- if rank == 0:
- tracking = Tracking(
- project_name=self.config.trainer.project_name,
- experiment_name=self.config.trainer.experiment_name,
- default_backend=self.config.trainer.logger,
- config=OmegaConf.to_container(self.config, resolve=True),
- )
-
- global_step = self.resume_global_step # Start from resumed step
- last_valid_metric = None
- # compute the total training steps.
- # the total training steps in SFT is mainly for early exit
- total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
-
- if self.config.trainer.total_training_steps is not None:
- total_training_steps = self.config.trainer.total_training_steps
-
- self.total_training_steps = total_training_steps
- log_with_rank(
- f"Total training steps: {self.total_training_steps},",
- logger=logger,
- rank=self.device_mesh.get_rank(),
- log_only_rank_0=True,
- )
-
- # With StatefulDataLoader, we don't need to manually calculate epochs and steps
- # The dataloader will automatically resume from where it left off
- if global_step > 0:
- log_with_rank(
- f"StatefulDataLoader will automatically resume from global step: {global_step}",
- logger=logger,
- rank=self.device_mesh.get_rank(),
- log_only_rank_0=True,
- )
-
- # Calculate which epoch we're starting from for sampler.set_epoch()
- start_epoch = global_step // self.steps_per_epoch
-
- train_time = 0
- for epoch in range(start_epoch, self.config.trainer.total_epochs):
- self.train_sampler.set_epoch(epoch=epoch)
-
- for step_in_epoch, data in enumerate(
- tqdm(
- self.train_dataloader,
- initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0,
- total=self.steps_per_epoch,
- desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}",
- disable=rank != 0,
- )
- ):
- global_step += 1
- data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name)
- metric = self.training_step(data)
- train_time += metric["train/time(s)"]
- if rank == 0:
- tracking.log(data=metric, step=global_step)
-
- is_last_step = global_step >= self.total_training_steps
- is_valid_step = global_step % self.config.trainer.test_freq == 0
- is_save_step = global_step % self.config.trainer.save_freq == 0
-
- # early exit or validation step
- if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step):
- # Perform validation
- val_losses = []
- for val_data in self.val_dataloader:
- val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(
- self.device_name
- )
- val_loss = self.validation_step(val_data)
- val_losses.append(val_loss)
- if rank == 0:
- val_loss = torch.mean(torch.stack(val_losses))
- metric = {"val/loss": val_loss.detach().item()}
- tracking.log(data=metric, step=global_step)
- last_valid_metric = metric
- torch.distributed.barrier()
-
- if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step):
- self.save_checkpoint(step=global_step)
-
- if is_last_step:
- if rank == 0:
- print(f"Total time for train steps: {train_time:.2f}s")
- print(f"Final validation metrics: {last_valid_metric}")
- return
-
-
-def run_sft(config):
- device_name = get_device_name()
- local_rank, rank, world_size = initialize_global_process_group()
-
- device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
- dp_size = world_size // config.ulysses_sequence_parallel_size
- ulysses_device_mesh = init_device_mesh(
- device_type=device_name,
- mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
- mesh_dim_names=("dp", "sp"),
- )
- # build tokenizer and datasets first
- from verl.utils import hf_tokenizer
-
- local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
- tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
- train_dataset = create_sft_dataset(
- config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
- )
- val_dataset = create_sft_dataset(
- config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
- )
-
- trainer = FSDPSFTTrainer(
- config=config,
- device_mesh=device_mesh,
- ulysses_device_mesh=ulysses_device_mesh,
- tokenizer=tokenizer,
- train_dataset=train_dataset,
- val_dataset=val_dataset,
- )
-
- trainer.fit()
-
- destroy_global_process_group()
-
-
-@hydra.main(config_path="config", config_name="sft_trainer", version_base=None)
-def main(config):
- # Automatically set `config.trainer.device = npu` when running on Ascend NPU.
- auto_set_device(config)
-
- run_sft(config)
-
-
-def create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1):
- """Create a dataset."""
- # build dataset
- # First check if a custom dataset class is specified
- if data_config.custom_cls.get("path", None):
- from verl.utils.import_utils import load_extern_object
-
- dataset_cls = load_extern_object(data_config.custom_cls.path, data_config.custom_cls.name)
- # Then check if multi-turn dataset should be used
- elif data_config.get("multiturn", {}).get("enable", False):
- dataset_cls = MultiTurnSFTDataset
- # Default to single-turn dataset
- else:
- dataset_cls = SFTDataset
-
- # Create datasets based on the selected class
- dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples)
- return dataset
-
-
-if __name__ == "__main__":
- main()
diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py
deleted file mode 100644
index 18aaa8cdbd0..00000000000
--- a/verl/trainer/main_generation.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Generate responses given a dataset of prompts
-"""
-
-import os
-
-import hydra
-import numpy as np
-import ray
-
-os.environ["NCCL_DEBUG"] = "WARN"
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-# os.environ['TORCH_COMPILE_DISABLE'] = '1'
-
-from pprint import pprint
-
-import pandas as pd
-from omegaconf import OmegaConf
-
-from verl import DataProto
-from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
-from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
-from verl.utils import hf_tokenizer
-from verl.utils.fs import copy_to_local
-from verl.utils.hdfs_io import makedirs
-from verl.utils.model import compute_position_id_with_mask
-from verl.workers.fsdp_workers import ActorRolloutRefWorker
-
-
-@hydra.main(config_path="config", config_name="generation", version_base=None)
-def main(config):
- run_generation(config)
-
-
-def run_generation(config) -> None:
- if not ray.is_initialized():
- # this is for local ray cluster
- default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
- ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
- runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
- runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
- ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
- print(f"ray init kwargs: {ray_init_kwargs}")
- ray.init(**OmegaConf.to_container(ray_init_kwargs))
-
- ray.get(main_task.remote(config))
-
-
-@ray.remote(num_cpus=1)
-def main_task(config):
- pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
- OmegaConf.resolve(config)
-
- local_path = copy_to_local(config.model.path)
- trust_remote_code = config.data.get("trust_remote_code", False)
- tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
-
- if config.rollout.temperature == 0.0:
- assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1."
- assert config.data.n_samples >= 1, "n_samples should always >= 1"
-
- # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
- dataset = pd.read_parquet(config.data.path)
- chat_lst = dataset[config.data.prompt_key].tolist()
-
- chat_lst = [chat.tolist() for chat in chat_lst]
-
- tokenizer.padding_side = "left"
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
- resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
-
- wg = RayWorkerGroup(
- resource_pool=resource_pool,
- ray_cls_with_init=ray_cls_with_init,
- device_name=config.trainer.device,
- )
- wg.init_model()
-
- total_samples = len(dataset)
- config_batch_size = config.data.batch_size
- apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {})
- num_batch = -(-total_samples // config_batch_size)
- output_lst = [[] for _ in range(config.data.n_samples)]
-
- for batch_idx in range(num_batch):
- print(f"[{batch_idx + 1}/{num_batch}] Start to process.")
- batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size]
- inputs = tokenizer.apply_chat_template(
- batch_chat_lst,
- add_generation_prompt=True,
- padding=True,
- truncation=True,
- max_length=config.rollout.prompt_length,
- return_tensors="pt",
- return_dict=True,
- tokenize=True,
- **apply_chat_template_kwargs,
- )
- input_ids = inputs["input_ids"]
- attention_mask = inputs["attention_mask"]
- position_ids = compute_position_id_with_mask(attention_mask)
- batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
-
- data = DataProto.from_dict(batch_dict)
- data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)
-
- # START TO GENERATE FOR n_samples TIMES
- print(f"[{batch_idx + 1}/{num_batch}] Start to generate.")
- for n_sample in range(config.data.n_samples):
- output_padded = wg.generate_sequences(data_padded)
- output = unpad_dataproto(output_padded, pad_size=pad_size)
-
- output_texts = []
- for i in range(len(output)):
- data_item = output[i]
- prompt_length = data_item.batch["prompts"].shape[-1]
- valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
- valid_response_ids = data_item.batch["responses"][:valid_response_length]
- response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True)
- output_texts.append(response_str)
-
- output_lst[n_sample].extend(output_texts)
-
- # convert output_lst from (n_samples, n_data) to (n_data, n_sampels)
- output_lst = np.array(output_lst, dtype=object)
- output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()
-
- # add to the data frame
- dataset["responses"] = output_lst
-
- # write to a new parquet
- output_dir = os.path.dirname(config.data.output_path)
- makedirs(output_dir, exist_ok=True)
- dataset.to_parquet(config.data.output_path)
-
-
-if __name__ == "__main__":
- main()
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index b82414a2880..2c84374d245 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -162,8 +162,13 @@ def add_actor_rollout_worker(self, config):
actor_rollout_cls = AsyncActorRolloutRefWorker
ray_worker_group_cls = RayWorkerGroup
- elif config.actor_rollout_ref.actor.strategy == "veomni":
- raise NotImplementedError("VeOmni does not support legacy worker implementation")
+ elif (
+ config.actor_rollout_ref.actor.strategy == "veomni"
+ or config.actor_rollout_ref.actor.strategy == "torchtitan"
+ ):
+ raise NotImplementedError(
+ f"{config.actor_rollout_ref.actor.strategy} does not support legacy worker implementation"
+ )
else:
raise NotImplementedError
@@ -191,14 +196,16 @@ def add_critic_worker(self, config):
# TODO: switch this to TrainingWorker as well
from verl.workers.megatron_workers import CriticWorker
- elif config.critic.strategy == "veomni":
+ elif config.critic.strategy == "veomni" or config.critic.strategy == "torchtitan":
if use_legacy_worker_impl == "disable":
from verl.workers.engine_workers import TrainingWorker
CriticWorker = TrainingWorker
- print("Using new worker implementation")
+ print(f"Using new worker implementation for {config.critic.strategy}")
else:
- raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
+ raise ValueError(
+ f"Invalid use_legacy_worker_impl for {config.critic.strategy}: {use_legacy_worker_impl}"
+ )
else:
raise NotImplementedError
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 2039fe56f62..b222dc1b705 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -763,7 +763,7 @@ def compute_optimal_token_baseline_advantage(
old_log_probs: torch.Tensor,
sum_pi_squared: torch.Tensor,
rollout_is_weights: torch.Tensor = None,
- handle_zero_tail: bool = False,
+ handle_zero_tail: bool = True,
epsilon: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -791,7 +791,7 @@ def compute_optimal_token_baseline_advantage(
None if not using IS
handle_zero_tail: If True, zero baselines will be set in the portion of the longest trajectory
that extends beyond the second-longest trajectory in the prompt group.
- Default: False
+ Default: True
epsilon: Small constant for numerical stability (default: 1e-8)
Returns:
@@ -1054,26 +1054,32 @@ def agg_loss(
"""
if loss_agg_mode == "token-mean":
if batch_num_tokens is None:
+ if dp_size > 1:
+ raise ValueError("(global) batch_num_tokens is required when dp_size > 1")
batch_num_tokens = loss_mask.sum()
loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
- elif loss_agg_mode == "seq-mean-token-sum":
+ elif loss_agg_mode in ["seq-mean-token-sum", "seq-mean-token-sum-norm"]:
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
if global_batch_size is None:
+ if dp_size > 1:
+ raise ValueError("global_batch_size is required when dp_size > 1")
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
+ if loss_agg_mode == "seq-mean-token-sum-norm":
+ if loss_scale_factor is None:
+ horizon = loss_mask.shape[-1]
+ loss_scale_factor = horizon
+ loss /= loss_scale_factor
elif loss_agg_mode == "seq-mean-token-mean":
seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean
seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
if global_batch_size is None:
+ if dp_size > 1:
+ raise ValueError("global_batch_size is required when dp_size > 1")
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
- elif loss_agg_mode == "seq-mean-token-sum-norm":
- seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
- if loss_scale_factor is None:
- loss_scale_factor = loss_mask.shape[-1]
- loss = torch.sum(seq_losses) / loss_scale_factor
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
@@ -1250,6 +1256,172 @@ def compute_policy_loss_vanilla(
return pg_loss, pg_metrics
+@register_policy_loss("dppo_tv")
+def compute_policy_loss_dppo_tv(
+ old_log_prob: torch.Tensor,
+ log_prob: torch.Tensor,
+ advantages: torch.Tensor,
+ response_mask: torch.Tensor,
+ loss_agg_mode: str = "token-mean",
+ config: Optional[ActorConfig] = None,
+ rollout_is_weights: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, dict[str, Any]]:
+ """
+ Compute the clipped policy objective and related metrics for DPPO-Binary-TV.
+
+ See https://arxiv.org/pdf/2602.04879 for more details.
+
+ Args:
+ old_log_prob (torch.Tensor):
+ Log-probabilities of actions under the old policy, shape (batch_size, response_length).
+ log_prob (torch.Tensor):
+ Log-probabilities of actions under the current policy, shape (batch_size, response_length).
+ advantages (torch.Tensor):
+ Advantage estimates for each action, shape (batch_size, response_length).
+ response_mask (torch.Tensor):
+ Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
+ loss_agg_mode (str, optional):
+ Aggregation mode for `agg_loss`. Defaults to "token-mean".
+ config: `(verl.trainer.config.ActorConfig)`:
+ config for the actor.
+ rollout_log_probs: `(torch.Tensor)`:
+ log probabilities of actions under the rollout policy, shape (batch_size, response_length).
+ """
+
+ assert config is not None
+ assert not isinstance(config, AlgoConfig)
+ # Note: the clip_ratio is different from the standard PPO, it is the TV divergence threshold for DPPO.
+ clip_divergence = config.clip_ratio
+ clip_divergence_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_divergence
+ clip_divergence_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_divergence
+
+ negative_approx_kl = log_prob - old_log_prob
+ # Clamp negative_approx_kl for stability
+ negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+ ratio = torch.exp(negative_approx_kl)
+ ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
+
+ # Instead of dual-clip PPO, we use truncated importance sampling (TIS) to clip the policy loss.
+ # However, a large threshold is recommended to avoid performance degradation due to the truncation bias.
+ # See Section 5.4 in https://arxiv.org/pdf/2602.04879 for more details.
+ clip_ratio_c = config.get("clip_ratio_c", 20.0)
+ truncated_ratio = torch.clamp(ratio, max=clip_ratio_c)
+ truncated_ratio = truncated_ratio.detach()
+
+ # Compute valid mask for DPPO-Binary-TV
+ prob = torch.exp(log_prob)
+ old_prob = torch.exp(old_log_prob)
+ valid_positive_mask = (prob - old_prob) <= clip_divergence_high
+ valid_negative_mask = (prob - old_prob) >= -clip_divergence_low
+ valid_mask = torch.where(advantages > 0, valid_positive_mask, valid_negative_mask)
+ valid_mask = valid_mask.detach().float()
+
+ pg_losses = -advantages * truncated_ratio * log_prob * valid_mask
+
+ # Apply rollout correction weights if provided
+ if rollout_is_weights is not None:
+ pg_losses = pg_losses * rollout_is_weights
+
+ pg_loss = agg_loss(
+ loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
+ )
+
+ pg_clipfrac = verl_F.masked_mean((1.0 - valid_mask).float(), response_mask)
+ pg_clipfrac_lower = verl_F.masked_mean((ratio > clip_ratio_c).float() * valid_mask, response_mask)
+
+ pg_metrics = {
+ "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+ "actor/ppo_kl": ppo_kl.detach().item(),
+ "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+ }
+ return pg_loss, pg_metrics
+
+
+@register_policy_loss("dppo_kl")
+def compute_policy_loss_dppo_kl(
+ old_log_prob: torch.Tensor,
+ log_prob: torch.Tensor,
+ advantages: torch.Tensor,
+ response_mask: torch.Tensor,
+ loss_agg_mode: str = "token-mean",
+ config: Optional[ActorConfig] = None,
+ rollout_is_weights: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, dict[str, Any]]:
+ """
+ Compute the clipped policy objective and related metrics for DPPO-Binary-KL.
+
+ See https://arxiv.org/pdf/2602.04879 for more details.
+
+ Args:
+ old_log_prob (torch.Tensor):
+ Log-probabilities of actions under the old policy, shape (batch_size, response_length).
+ log_prob (torch.Tensor):
+ Log-probabilities of actions under the current policy, shape (batch_size, response_length).
+ advantages (torch.Tensor):
+ Advantage estimates for each action, shape (batch_size, response_length).
+ response_mask (torch.Tensor):
+ Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
+ loss_agg_mode (str, optional):
+ Aggregation mode for `agg_loss`. Defaults to "token-mean".
+ config: `(verl.trainer.config.ActorConfig)`:
+ config for the actor.
+ rollout_log_probs: `(torch.Tensor)`:
+ log probabilities of actions under the rollout policy, shape (batch_size, response_length).
+ """
+
+ assert config is not None
+ assert not isinstance(config, AlgoConfig)
+ # Note: the clip_ratio is different from the standard PPO, it is the KL divergence threshold for DPPO.
+ clip_divergence = config.clip_ratio
+ clip_divergence_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_divergence
+ clip_divergence_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_divergence
+
+ negative_approx_kl = log_prob - old_log_prob
+ # Clamp negative_approx_kl for stability
+ negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+ ratio = torch.exp(negative_approx_kl)
+ ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
+
+ # Instead of dual-clip PPO, we use truncated importance sampling (TIS) to clip the policy loss.
+ # However, a large threshold is recommended to avoid performance degradation due to the truncation bias.
+ # See Section 5.4 in https://arxiv.org/pdf/2602.04879 for more details.
+ clip_ratio_c = config.get("clip_ratio_c", 20.0)
+ truncated_ratio = torch.clamp(ratio, max=clip_ratio_c)
+ truncated_ratio = truncated_ratio.detach()
+
+ # Compute valid mask for DPPO-Binary-KL
+ prob = torch.exp(log_prob)
+ old_prob = torch.exp(old_log_prob)
+ binary_kl = old_prob * (old_log_prob - log_prob) + (1 - old_prob) * torch.log(
+ (1.0 - old_prob + 1e-8) / (1.0 - prob + 1e-8)
+ )
+ valid_positive_mask = (binary_kl <= clip_divergence_high) | (prob <= old_prob)
+ valid_negative_mask = (binary_kl <= clip_divergence_low) | (prob >= old_prob)
+ valid_mask = torch.where(advantages > 0, valid_positive_mask, valid_negative_mask)
+ valid_mask = valid_mask.detach().float()
+
+ pg_losses = -advantages * truncated_ratio * log_prob * valid_mask
+
+ # Apply rollout correction weights if provided
+ if rollout_is_weights is not None:
+ pg_losses = pg_losses * rollout_is_weights
+
+ pg_loss = agg_loss(
+ loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
+ )
+
+ # For compatibility, return zero for pg_clipfrac_lower (not used in standard DPPO)
+ pg_clipfrac = verl_F.masked_mean((1.0 - valid_mask).float(), response_mask)
+ pg_clipfrac_lower = verl_F.masked_mean((ratio > clip_ratio_c).float() * valid_mask, response_mask)
+
+ pg_metrics = {
+ "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+ "actor/ppo_kl": ppo_kl.detach().item(),
+ "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+ }
+ return pg_loss, pg_metrics
+
+
@register_policy_loss("gspo")
def compute_policy_loss_gspo(
old_log_prob: torch.Tensor,
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 196e55969ce..ae43d2bad5c 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -303,6 +303,8 @@ def __init__(
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
+ self.checkpoint_manager = None
+
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
"""
Creates the train and validation dataloaders.
@@ -481,8 +483,7 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto:
)
# For agent loop, we need reward model keys to compute score.
- if self.async_rollout_mode:
- gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
+ gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
return gen_batch
@@ -536,16 +537,9 @@ def _validate(self, merged: bool = False):
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
# pad to be divisible by dp_size
- size_divisor = (
- self.actor_rollout_wg.world_size
- if not self.async_rollout_mode
- else self.config.actor_rollout_ref.rollout.agent.num_workers
- )
+ size_divisor = self.config.actor_rollout_ref.rollout.agent.num_workers
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
- if not self.async_rollout_mode:
- test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
- else:
- test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
if self.use_rm and "rm_scores" not in test_output_gen_batch_padded.batch.keys():
# for colocate reward models, we need to sleep rollout model
@@ -765,6 +759,8 @@ def init_workers(self):
wg_kwargs["device_name"] = self.device_name
for resource_pool, class_dict in self.resource_pool_to_cls.items():
+ if not class_dict:
+ continue
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
@@ -835,15 +831,16 @@ def init_workers(self):
# if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager
# to stream reward computation with actor rollout
reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
- self.async_rollout_manager = AgentLoopManager(
+ self.async_rollout_manager = AgentLoopManager.create(
config=self.config,
worker_group=self.actor_rollout_wg,
rollout_resource_pool=actor_rollout_resource_pool,
reward_loop_worker_handles=reward_loop_worker_handles,
)
+ checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine)
self.checkpoint_manager = CheckpointEngineManager(
- backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
+ config=checkpoint_engine_config,
trainer=self.actor_rollout_wg,
replicas=self.async_rollout_manager.rollout_replicas,
)
@@ -1258,7 +1255,7 @@ def fit(self):
return
if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
- rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
+ rollout_skip = RolloutSkip(self.config, self.async_rollout_manager)
rollout_skip.wrap_generate_sequences()
# add tqdm
@@ -1310,15 +1307,12 @@ def fit(self):
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
- if not self.async_rollout_mode:
- gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
- else:
- if curr_step_profile:
- self.async_rollout_manager.start_profile()
- gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
- self.checkpoint_manager.sleep_replicas()
- if curr_step_profile:
- self.async_rollout_manager.stop_profile()
+ if curr_step_profile:
+ self.async_rollout_manager.start_profile()
+ gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
+ self.checkpoint_manager.sleep_replicas()
+ if curr_step_profile:
+ self.async_rollout_manager.stop_profile()
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
@@ -1327,15 +1321,12 @@ def fit(self):
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
- if not self.async_rollout_mode:
- gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
- else:
- if curr_step_profile:
- self.async_rollout_manager.start_profile()
- gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
- self.checkpoint_manager.sleep_replicas()
- if curr_step_profile:
- self.async_rollout_manager.stop_profile()
+ if curr_step_profile:
+ self.async_rollout_manager.start_profile()
+ gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+ self.checkpoint_manager.sleep_replicas()
+ if curr_step_profile:
+ self.async_rollout_manager.stop_profile()
batch = batch.union(gen_baseline_output)
# compute reward model score on batch
rm_scores = None
diff --git a/verl/trainer/sft_trainer.py b/verl/trainer/sft_trainer.py
index 979d92b04a1..d23ebc5fa90 100644
--- a/verl/trainer/sft_trainer.py
+++ b/verl/trainer/sft_trainer.py
@@ -238,8 +238,14 @@ def _get_batch_seqlens(self, data):
batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp)
+ dp_group = self.engine.get_data_parallel_group()
+ dp_size = self.engine.get_data_parallel_size()
+
+ if dp_size == 1 or dp_group is None:
+ return batch_seqlens.tolist()
+
output_tensor = torch.empty(
- (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),),
+ (batch_seqlens.shape[0] * dp_size,),
dtype=batch_seqlens.dtype,
device=self.device_name,
) # (global_bsz,)
@@ -247,7 +253,7 @@ def _get_batch_seqlens(self, data):
torch.distributed.all_gather_into_tensor(
output_tensor=output_tensor,
input_tensor=batch_seqlens,
- group=self.engine.get_data_parallel_group(),
+ group=dp_group,
)
batch_seqlens = output_tensor.tolist()
@@ -372,9 +378,9 @@ def fit(self):
if self.engine.is_mp_src_rank_with_outputs():
val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name))
# average over data parallel group
- torch.distributed.all_reduce(
- val_loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()
- )
+ dp_group = self.engine.get_data_parallel_group()
+ if dp_group is not None:
+ torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.AVG, group=dp_group)
if is_logging:
metric = {"val/loss": val_loss.detach().item()}
diff --git a/verl/utils/dataset/__init__.py b/verl/utils/dataset/__init__.py
index 6032d68c864..423bd9a9ccd 100644
--- a/verl/utils/dataset/__init__.py
+++ b/verl/utils/dataset/__init__.py
@@ -14,6 +14,5 @@
from .rl_dataset import RLHFDataset
from .rm_dataset import RMDataset
-from .sft_dataset import SFTDataset
-__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"]
+__all__ = ["RLHFDataset", "RMDataset"]
diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py
index 9da33228e21..081d1dcfafa 100644
--- a/verl/utils/dataset/multiturn_sft_dataset.py
+++ b/verl/utils/dataset/multiturn_sft_dataset.py
@@ -36,6 +36,7 @@
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.dataset.vision_utils import process_image, process_video
from verl.utils.fs import copy_local_path_from_hdfs
+from verl.utils.py_functional import convert_nested_value_to_list_recursive
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -68,19 +69,6 @@ def print_assembled_message(tokenizer, message_list, input_ids, loss_mask, attn_
logger.debug(str)
-def convert_nested_value_to_list_recursive(data_item):
- if isinstance(data_item, dict):
- return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()}
- elif isinstance(data_item, list):
- return [convert_nested_value_to_list_recursive(elem) for elem in data_item]
- elif isinstance(data_item, np.ndarray):
- # Convert to list, then recursively process the elements of the new list
- return convert_nested_value_to_list_recursive(data_item.tolist())
- else:
- # Base case: item is already a primitive type (int, str, float, bool, etc.)
- return data_item
-
-
class MultiTurnSFTDataset(Dataset):
"""
Dataset for multi-turn conversations where each assistant response should be trained
diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py
deleted file mode 100644
index 5fa8e07b252..00000000000
--- a/verl/utils/dataset/sft_dataset.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-SFT dataset
-- We assume user pass a single parquet file.
-- We load all the data into the memory.
-Each parquet file contains
-"""
-
-import numpy as np
-import pandas as pd
-import torch
-from omegaconf.listconfig import ListConfig
-from torch.utils.data import Dataset
-from transformers import PreTrainedTokenizer
-
-from verl.utils import hf_tokenizer
-from verl.utils.fs import copy_to_local
-from verl.utils.model import compute_position_id_with_mask
-
-
-class SFTDataset(Dataset):
- """
- This is an in-memory SFTDataset
-
- Arguments:
- config (OmegaConf): the data config
- """
-
- def __init__(self, parquet_files: str | ListConfig, tokenizer, config, max_samples: int = -1):
- prompt_key = config.get("prompt_key", "prompt")
- prompt_dict_keys = config.get("prompt_dict_keys", None)
- response_key = config.get("response_key", "response")
- response_dict_keys = config.get("response_dict_keys", None)
- max_length = config.get("max_length", 1024)
- truncation = config.get("truncation", "error")
- use_shm = config.get("use_shm", False)
- self.shuffle = config.get("shuffle", False)
- self.seed = config.get("seed")
- self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {})
-
- assert truncation in ["error", "left", "right"]
- self.truncation = truncation
- self.use_shm = use_shm
-
- if not isinstance(parquet_files, ListConfig):
- parquet_files = [parquet_files]
-
- self.parquet_files = parquet_files
- self.max_samples = max_samples
- if isinstance(tokenizer, str):
- tokenizer = hf_tokenizer(tokenizer)
- self.tokenizer: PreTrainedTokenizer = tokenizer
-
- self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key]
- self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key]
- self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else []
- self.response_dict_keys = response_dict_keys if response_dict_keys else []
-
- self.max_length = max_length
-
- self._download()
- self._read_files_and_tokenize()
-
- def _download(self):
- for i, parquet_file in enumerate(self.parquet_files):
- self.parquet_files[i] = copy_to_local(parquet_file, verbose=True, use_shm=self.use_shm)
-
- def _read_files_and_tokenize(self):
- def series_to_item(ls):
- import numpy
- import pandas
-
- while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:
- ls = ls[0]
- return ls
-
- dataframes = []
- for parquet_file in self.parquet_files:
- # read parquet files and cache
- dataframe = pd.read_parquet(parquet_file)
- dataframes.append(dataframe)
- self.dataframe = pd.concat(dataframes)
-
- total = len(self.dataframe)
- print(f"dataset len: {len(self.dataframe)}")
-
- if self.max_samples > 0 and self.max_samples < total:
- if self.shuffle:
- rngs_args = (self.seed,) if self.seed is not None else ()
- rng = np.random.default_rng(*rngs_args)
- indices = rng.choice(total, size=self.max_samples, replace=False)
- else:
- indices = np.arange(self.max_samples)
- self.dataframe = self.dataframe.iloc[indices.tolist()]
- print(f"selected {self.max_samples} random samples out of {total}")
-
- self.prompts = self.dataframe[self.prompt_key]
- for key in self.prompt_dict_keys:
- # type(x): pandas.core.series.Series
- # type(x[0]): numpy.ndarray
- # type(x[0][0]): dict
- try:
- self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1) # noqa: B023
- except Exception:
- print(f"self.prompts={self.prompts}")
- raise
- if isinstance(self.prompts, pd.DataFrame):
- self.prompts = self.prompts.squeeze()
- self.prompts = self.prompts.tolist()
- self.responses = self.dataframe[self.response_key]
- for key in self.response_dict_keys:
- try:
- self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1) # noqa: B023
- except Exception:
- print(f"self.responses={self.responses}")
- raise
- if isinstance(self.responses, pd.DataFrame):
- self.responses = self.responses.squeeze()
- self.responses = self.responses.tolist()
-
- def __len__(self):
- return len(self.prompts)
-
- def __getitem__(self, item):
- tokenizer = self.tokenizer
-
- prompt = self.prompts[item]
- response = self.responses[item]
-
- # apply chat template
- prompt_chat = [{"role": "user", "content": prompt}]
-
- # string
- prompt_chat_str = tokenizer.apply_chat_template(
- prompt_chat, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
- )
- response_chat_str = response + tokenizer.eos_token
-
- # tokenize
- prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False)
- prompt_ids = prompt_ids_output["input_ids"][0]
- prompt_attention_mask = prompt_ids_output["attention_mask"][0]
-
- response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False)
- response_ids = response_ids_output["input_ids"][0]
- response_attention_mask = response_ids_output["attention_mask"][0]
-
- prompt_length = prompt_ids.shape[0]
- response_length = response_ids.shape[0]
-
- input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
- attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
-
- # padding to max length
- sequence_length = input_ids.shape[0]
- if sequence_length < self.max_length:
- padded_input_ids = (
- torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype)
- * self.tokenizer.pad_token_id
- )
- padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype)
-
- input_ids = torch.cat((input_ids, padded_input_ids))
- attention_mask = torch.cat((attention_mask, padded_attention_mask))
- elif sequence_length > self.max_length:
- if self.truncation == "left":
- # actually, left truncation may not be reasonable
- input_ids = input_ids[-self.max_length :]
- attention_mask = attention_mask[-self.max_length :]
- elif self.truncation == "right":
- input_ids = input_ids[: self.max_length]
- attention_mask = attention_mask[: self.max_length]
- elif self.truncation == "error":
- raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}")
- else:
- raise NotImplementedError(f"Unknown truncation method {self.truncation}")
-
- position_ids = compute_position_id_with_mask(attention_mask)
-
- loss_mask = attention_mask.clone()
- if prompt_length > 1:
- # mask out prompt for SFT.
- loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0
- # mask out the last token in response
- loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0
-
- return {
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- "position_ids": position_ids,
- "loss_mask": loss_mask,
- }
diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 70f8cf600ee..8f7f8bef0d0 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -144,7 +144,7 @@ def lambda_policy_fn(module):
@torch.no_grad()
def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
- if fsdp_version(model) == 2:
+ if fsdp_version(model) == 2 or fsdp_version(model) == 0:
offload_fsdp2_model_to_cpu(model, empty_cache)
return
@@ -178,7 +178,7 @@ def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
@torch.no_grad()
def load_fsdp_model_to_gpu(model: FSDP):
- if fsdp_version(model) == 2:
+ if fsdp_version(model) == 2 or fsdp_version(model) == 0:
load_fsdp2_model_to_gpu(model)
return
@@ -438,7 +438,7 @@ def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True
):
state_dict = model.state_dict()
return state_dict
- elif fsdp_version(model) == 2:
+ elif fsdp_version(model) == 2 or fsdp_version(model) == 0:
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
state_dict_config = StateDictOptions(
diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py
index 40c6db492c9..3eca28cb6fd 100644
--- a/verl/utils/kernel/kernels.py
+++ b/verl/utils/kernel/kernels.py
@@ -263,6 +263,7 @@ def efficient_entropy_kernel_general_mainloop(
_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
_logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
+ vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size)
for n in range(0, num_pid_n):
start_offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N
offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N)
@@ -308,12 +309,14 @@ def efficient_entropy_kernel_general_mainloop(
# scale logits by temperature
logits *= rcp_temperature
+ logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf"))
+
# update global maximum
_max_old = _max
- m_pid_n = tl.max(logits, axis=1)
+ m_pid_n = tl.max(logits_for_lse, axis=1)
_max = tl.maximum(_max_old, m_pid_n)
- exp_logits = tl.exp(logits - _max[:, None])
+ exp_logits = tl.exp(logits_for_lse - _max[:, None])
coeff = tl.exp(_max_old - _max)
_accu = coeff * _accu + tl.sum(exp_logits, axis=1)
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py
index 9572fb91962..708c6e24fa5 100644
--- a/verl/utils/megatron_utils.py
+++ b/verl/utils/megatron_utils.py
@@ -269,6 +269,8 @@ def peft_pre_wrap_hook(model):
model = provider.provide_distributed_model(
wrap_with_ddp=wrap_config.wrap_with_ddp,
ddp_config=ddp_config,
+ fp16=provider.fp16,
+ bf16=provider.bf16,
)
# Extract TransformerConfig from the created model
@@ -440,6 +442,11 @@ def offload_megatron_model_to_cpu(models):
# if the grad_data size is already zero, we assume that it is already offloaded
buffer.grad_data_size = buffer.grad_data.storage().size()
buffer.grad_data.storage().resize_(0)
+ # Offload frozen parameters not in DDP buffers (e.g. base model in LoRA/PEFT)
+ # DDP buffers only contain requires_grad=True params, so frozen params must be offloaded separately.
+ for param in model_chunk.module.parameters():
+ if not param.requires_grad and param.device.type != "cpu":
+ param.data = param.data.to("cpu", non_blocking=True)
else:
# we need this for ref module
for _, param in model_chunk.named_parameters():
@@ -451,7 +458,14 @@ def offload_megatron_model_to_cpu(models):
@torch.no_grad()
-def load_megatron_model_to_gpu(models, load_grad=True):
+def load_megatron_model_to_gpu(models, load_grad=True, load_frozen_params=True):
+ """
+ Load megatron model to GPU.
+ Args:
+ models: The model to load.
+ load_grad: Whether to load gradients.
+ load_frozen_params: Whether to load frozen parameters.
+ """
for model_chunk in models:
if isinstance(model_chunk, DDP):
model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]
@@ -466,6 +480,13 @@ def load_megatron_model_to_gpu(models, load_grad=True):
buffer.param_data.storage().resize_(buffer.param_data_size)
# copy data from cpu to cuda
buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True)
+
+ # Load frozen parameters that were offloaded (e.g. base model in LoRA/PEFT)
+ if load_frozen_params:
+ device_id = get_device_id()
+ for param in model_chunk.module.parameters():
+ if not param.requires_grad and param.device.type == "cpu":
+ param.data = param.data.to(device_id, non_blocking=True)
else:
# we need this for ref module
device_id = get_device_id()
diff --git a/verl/utils/memory_buffer.py b/verl/utils/memory_buffer.py
deleted file mode 100644
index 9386f0d88bc..00000000000
--- a/verl/utils/memory_buffer.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This file contains utilities to manipulate torch memory buffers
-"""
-
-from typing import Optional
-
-import torch
-from torch import nn
-
-from verl.utils.device import get_device_name
-
-
-class MemoryBuffer:
- """
- A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying
- memory. It must have a unique type to support this behavior.
- """
-
- def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None):
- self.numel = numel
- self.numel_padded = numel_padded
- self.dtype = dtype
- if source is not None:
- self.data = source
- else:
- self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False)
-
- def zero(self):
- """Reset the buffer to zero."""
- self.data.zero_()
-
- def get(self, shape, start_index):
- """Return a tensor with the input `shape` as a view into the
- 1-D data starting at `start_index`."""
- end_index = start_index + shape.numel()
- assert end_index <= self.numel, "requested tensor is out of the buffer range."
- buffer_tensor = self.data[start_index:end_index]
- buffer_tensor = buffer_tensor.view(shape)
- return buffer_tensor
-
-
-def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
- """for cuda memory alignment, make sure alignment by 128-bits"""
- align_numel = 128 // torch.finfo(dtype).bits
- numel = shape.numel()
- return (numel + align_numel - 1) // align_numel * align_numel
-
-
-def get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]:
- """
- Return a dictionary containing name to a shape and dtype.
- """
- weight_buffer_meta = {}
- for name, param in sorted(module.named_parameters()):
- weight_buffer_meta[name] = {"shape": param.shape, "dtype": param.dtype}
- return weight_buffer_meta
-
-
-def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]:
- """Build the memory buffer given weight_buffer_meta
-
- Args:
- weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors
-
- Returns: a large memory buffer for each dtype that can hold all the tensors
-
- """
- memory_buffers = {}
- total_numel_map = {} # map from dtype to the total numel
- for name, meta_info in sorted(weight_buffer_meta.items()):
- shape = meta_info["shape"]
- dtype = meta_info["dtype"]
-
- assert isinstance(shape, torch.Size)
- assert isinstance(dtype, torch.dtype)
-
- if dtype not in total_numel_map:
- total_numel_map[dtype] = 0
-
- total_numel_map[dtype] += calc_padded_numel(shape, dtype)
-
- for dtype, total_numel in total_numel_map.items():
- memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)
-
- return memory_buffers
-
-
-def build_memory_reference_from_module(
- module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True
-):
- start_index = {}
- for dtype in memory_buffers:
- start_index[dtype] = 0
- for name, param in sorted(module.named_parameters()):
- memory_buffer = memory_buffers[param.dtype]
- buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])
- # need to increment start_index
- start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)
- if maintain_weight:
- buffer.copy_(param.data)
- param.data = buffer
-
-
-def build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]):
- """Build the memory references. The memory buffers are built using the build_memory_buffer API.
- This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.
-
- Args:
- weight_buffer_meta:
- memory_buffers:
-
- Returns:
-
- """
- start_idx = {}
- weight_buffers = {}
- for dtype in memory_buffers:
- start_idx[dtype] = 0
-
- for name, meta_info in sorted(weight_buffer_meta.items()):
- shape = meta_info["shape"]
- dtype = meta_info["dtype"]
-
- buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])
- start_idx[dtype] += calc_padded_numel(shape, dtype)
- weight_buffers[name] = buffer
-
- return weight_buffers
-
-
-class MemoryBufferModuleWrapper:
- """
- Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to
- - It will change the checkpoint name
- """
-
- def __init__(self, module: nn.Module):
- super().__init__()
- self.module = module
- self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)
- self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)
- build_memory_reference_from_module(self.module, self.memory_buffers)
-
- def get_memory_buffers(self):
- return self.memory_buffers
-
- def get_weight_buffer_meta(self):
- return self.weight_buffer_meta
-
-
-class MegatronMemoryBufferForRollout:
- """
- We assume that
- - inference engine has tp + dp
- - actor has tp + pp + dp
- - the tp between inference engine and actor should be the same
- - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer
- - weight_buffers: contains a list of weight_buffers, each is a dict from name to param
- - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that
- the named_parameters may not be directly compatible with inference engine. User has to take care of
- this part such as the layout mismatches. (e.g. qkv transpose)
- - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.
- - When doing weight sync, the data is transfer via memory buffers
- """
-
- def __init__(self, transform_memory_param_fn):
- self._memory_buffers = []
- self._weight_buffers = []
- self._named_parameters = {}
- self.transform_memory_param_fn = transform_memory_param_fn
-
- def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]):
- """
- Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct
- a large buffer for each dtype in the weight_buffer.
-
- Args:
- weight_buffer_meta: contains pp models, each pp models contains a dictionary of mapping from
-
- Returns: None
-
- """
- self.weight_buffer_meta_pp = weight_buffer_meta_pp
-
- for weight_buffer_meta in self.weight_buffer_meta_pp:
- memory_buffer = build_memory_buffer(weight_buffer_meta)
- self._memory_buffers.append(memory_buffer)
- self._weight_buffers.append(None)
-
- def build_memory_reference(self):
- for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp):
- self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i])
- self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)
-
- @property
- def named_parameters(self):
- return self._named_parameters
-
- @property
- def weight_buffers(self):
- return self._weight_buffers
-
- @property
- def memory_buffers(self):
- return self._memory_buffers
diff --git a/verl/utils/net_utils.py b/verl/utils/net_utils.py
index 1acef76a434..0c6ee35d118 100644
--- a/verl/utils/net_utils.py
+++ b/verl/utils/net_utils.py
@@ -70,15 +70,22 @@ def is_valid_ipv6_address(address: str) -> bool:
return False
-def get_free_port(address: str) -> tuple[int, socket.socket]:
- family = socket.AF_INET
- if is_valid_ipv6_address(address):
- family = socket.AF_INET6
+def get_free_port(address: str, with_alive_sock: bool = False) -> tuple[int, socket.socket | None]:
+ """Find a free port on the given address.
+
+ By default the socket is closed internally, suitable for immediate use.
+ Set with_alive_sock=True to keep the socket open as a port reservation,
+ preventing other calls from getting the same port. The caller is
+ responsible for closing the socket before the port is actually bound
+ by the target service (e.g. NCCL, uvicorn).
+ """
+ family = socket.AF_INET6 if is_valid_ipv6_address(address) else socket.AF_INET
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind((address, 0))
-
port = sock.getsockname()[1]
- return port, sock
+ if with_alive_sock:
+ return port, sock
+ sock.close()
+ return port, None
diff --git a/verl/utils/profiler/torch_profile.py b/verl/utils/profiler/torch_profile.py
index bd59ef54dca..e0ccd09f3d3 100644
--- a/verl/utils/profiler/torch_profile.py
+++ b/verl/utils/profiler/torch_profile.py
@@ -14,6 +14,7 @@
import functools
import os
+from datetime import datetime, timezone
from typing import Callable, Optional
import torch
@@ -34,7 +35,11 @@ def get_torch_profiler(
os.makedirs(save_path, exist_ok=True)
- save_file_name = f"prof_rank-{rank}.json.gz"
+ current_time = datetime.now(tz=timezone.utc).astimezone()
+ timestamp = current_time.strftime("%Y%m%d%H%M%S%f")[:-3]
+ pid = os.getpid()
+
+ save_file_name = f"prof_rank-{rank}_{pid}_{timestamp}.json.gz"
if save_file_prefix:
save_file_name = f"{save_file_prefix}_{save_file_name}"
save_path = os.path.join(save_path, save_file_name)
diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py
index 32dc8f52681..bb8a60bfea6 100644
--- a/verl/utils/py_functional.py
+++ b/verl/utils/py_functional.py
@@ -25,6 +25,8 @@
from types import SimpleNamespace
from typing import Any, Callable, Iterator, Optional
+import numpy as np
+
from verl.utils.metric import Metric
@@ -339,3 +341,28 @@ def convert_to_regular_types(obj):
elif isinstance(obj, dict):
return {k: convert_to_regular_types(v) for k, v in obj.items()}
return obj
+
+
+def convert_nested_value_to_list_recursive(data_item):
+ if isinstance(data_item, dict):
+ return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()}
+ elif isinstance(data_item, list):
+ return [convert_nested_value_to_list_recursive(elem) for elem in data_item]
+ elif isinstance(data_item, np.ndarray):
+ # Convert to list, then recursively process the elements of the new list
+ return convert_nested_value_to_list_recursive(data_item.tolist())
+ else:
+ # Base case: item is already a primitive type (int, str, float, bool, etc.)
+ return data_item
+
+
+def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
+ if len(list_of_dict) == 0:
+ return {}
+ keys = list_of_dict[0].keys()
+ output = {key: [] for key in keys}
+ for data in list_of_dict:
+ for key, item in data.items():
+ assert key in output, f"Key '{key}' is not present in the keys of the first dictionary in the list."
+ output[key].append(item)
+ return output
diff --git a/verl/utils/qat/__init__.py b/verl/utils/qat/__init__.py
new file mode 100644
index 00000000000..6f2a85c814d
--- /dev/null
+++ b/verl/utils/qat/__init__.py
@@ -0,0 +1,57 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+QAT (Quantization-Aware Training) module for verl.
+
+Supports NVFP4 (W4A4 and W4A16) quantization modes for FSDP training.
+
+Module Structure:
+- core.py: QATConfig, apply_qat, enable_qat_fuse (training setup)
+- linear.py: QATLinear layer with Triton kernels for fake quantization
+- quantizer.py: QATQuantizer for true quantization + scale computation utilities
+- vllm_patch.py: Patches for vLLM dynamic weight loading
+
+Usage:
+ from verl.utils.qat import apply_qat, QATConfig
+
+ config = QATConfig(enable=True, mode="w4a16")
+ model = apply_qat(model, config) # Before FSDP wrapping
+"""
+
+from verl.utils.qat.core import (
+ QATConfig,
+ apply_qat,
+ enable_qat_fuse,
+ invalidate_all_scales,
+ load_quantization_config,
+)
+from verl.utils.qat.vllm_patch import (
+ apply_qat_patches,
+ manual_process_weights_after_loading,
+ prepare_qat_for_load_weights,
+)
+
+__all__ = [
+ # Core
+ "QATConfig",
+ "apply_qat",
+ "load_quantization_config",
+ "enable_qat_fuse",
+ "invalidate_all_scales",
+ # vLLM Patch
+ "apply_qat_patches",
+ "manual_process_weights_after_loading",
+ "prepare_qat_for_load_weights",
+]
diff --git a/verl/utils/qat/core.py b/verl/utils/qat/core.py
new file mode 100644
index 00000000000..9f3a1fbe6a8
--- /dev/null
+++ b/verl/utils/qat/core.py
@@ -0,0 +1,196 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""QAT (Quantization-Aware Training) utilities for verl FSDP training."""
+
+import json
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+import torch.nn as nn
+
+from verl.base_config import BaseConfig
+
+logger = logging.getLogger(__name__)
+
+
@dataclass
class QATConfig(BaseConfig):
    """Unified configuration for QAT (Quantization-Aware Training)."""

    # Master switch; when False, apply_qat() returns the model untouched.
    enable: bool = False
    # Quantization mode: "w4a16" (weight-only) or "w4a4" (weight + activation).
    mode: str = "w4a16"
    # NVFP4 block size; a Linear's in_features must be divisible by this.
    group_size: int = 16
    # Module-name patterns excluded from quantization; "re:" prefix marks a regex.
    ignore_patterns: list[str] = field(default_factory=lambda: ["lm_head", "embed_tokens", "re:.*mlp.gate$"])
    # Observer strategy for the activation global scale (used in W4A4 mode).
    activation_observer: str = "static_minmax"
    # Path to a quantization-config JSON; required by load_quantization_config().
    quantization_config_path: Optional[str] = None
+
+
def load_quantization_config(qat_config: QATConfig) -> dict[str, Any]:
    """Load the quantization config JSON referenced by *qat_config*.

    A non-empty ``qat_config.ignore_patterns`` overrides the JSON's own
    ``ignore`` field (logged when the two actually differ).

    Args:
        qat_config: QAT settings; only ``quantization_config_path`` and
            ``ignore_patterns`` are read.

    Returns:
        The parsed quantization config dictionary.

    Raises:
        ValueError: If ``quantization_config_path`` is not set.
        OSError, json.JSONDecodeError: If the file is unreadable or invalid.
    """
    if not qat_config.quantization_config_path:
        raise ValueError("quantization_config_path is required when QAT is enabled")

    # Lazy %-style args avoid string formatting when the record is filtered out.
    logger.info("Loading QAT quantization config from: %s", qat_config.quantization_config_path)

    # Explicit encoding so behavior does not depend on the platform default.
    with open(qat_config.quantization_config_path, encoding="utf-8") as f:
        quant_config = json.load(f)

    if qat_config.ignore_patterns:
        original_ignore = quant_config.get("ignore", [])
        quant_config["ignore"] = qat_config.ignore_patterns
        if original_ignore != qat_config.ignore_patterns:
            logger.info("Overriding JSON 'ignore' field: %s -> %s", original_ignore, qat_config.ignore_patterns)

    logger.info("Successfully loaded QAT quantization config")
    return quant_config
+
+
def _should_quantize(name: str, module: nn.Module, config: QATConfig) -> bool:
    """Decide whether the module at dotted path *name* should receive QAT.

    Only nn.Linear layers qualify; anything matching an ignore pattern or
    whose in_features is not a multiple of the group size is skipped.
    """
    if not isinstance(module, nn.Linear):
        return False

    for pattern in config.ignore_patterns:
        is_regex = pattern.startswith("re:")
        if is_regex and re.match(pattern[3:], name):
            logger.debug(f"Ignoring {name} due to regex pattern: {pattern[3:]}")
            return False
        if not is_regex and pattern in name:
            logger.debug(f"Ignoring {name} due to pattern: {pattern}")
            return False

    if module.in_features % config.group_size == 0:
        return True

    logger.warning(
        f"Skipping {name}: in_features={module.in_features} not divisible by group_size={config.group_size}"
    )
    return False
+
+
def apply_qat(
    model: nn.Module,
    config: QATConfig | dict[str, Any],
) -> nn.Module:
    """Apply QAT to a model by replacing nn.Linear with QATLinear."""
    from verl.utils.qat.linear import QATLinear, QATMode

    if not isinstance(config, QATConfig):
        config = QATConfig(**config)

    if not config.enable:
        logger.info("QAT is disabled, returning original model")
        return model

    mode = QATMode(config.mode.lower())
    logger.info(f"Applying QAT with mode={mode.value}, group_size={config.group_size}")

    # Snapshot the candidates first: mutating the module tree while iterating
    # named_modules() would be unsafe.
    candidates = [(name, module) for name, module in model.named_modules() if _should_quantize(name, module, config)]

    logger.info(f"Found {len(candidates)} Linear layers to convert to QAT")

    replaced = 0
    for name, module in candidates:
        if isinstance(module, QATLinear):
            continue

        _set_module(
            model,
            name,
            QATLinear.from_linear(
                module,
                mode=mode,
                group_size=config.group_size,
                activation_observer=config.activation_observer,
            ),
        )
        replaced += 1

    logger.info(f"Successfully applied QAT to {replaced} layers")

    return model
+
+
+def _set_module(model: nn.Module, name: str, new_module: nn.Module):
+ """Set a module in the model by its full name."""
+ parts = name.split(".")
+ parent = model
+ for part in parts[:-1]:
+ parent = getattr(parent, part)
+ setattr(parent, parts[-1], new_module)
+
+
# Suffix groups of projection layers that are fused downstream (QKV attention
# projections, gate/up MLP projections); members found under the same parent
# module are linked as "fusion siblings" by setup_fusion_siblings() so they
# share one weight scale.
FUSION_PATTERNS = {
    "qkv": ["q_proj", "k_proj", "v_proj"],
    "gate_up": ["gate_proj", "up_proj"],
}
+
+
def setup_fusion_siblings(model: nn.Module):
    """Link QKV / GateUp projection layers as fusion siblings.

    Every QATLinear in a complete fusion group receives weak references to
    the other group members so they can share one weight scale. Returns the
    number of linked groups per pattern name.
    """
    import weakref

    from verl.utils.qat.linear import QATLinear

    named_qat = [(name, m) for name, m in model.named_modules() if isinstance(m, QATLinear)]

    counts = {}
    for group_name, suffixes in FUSION_PATTERNS.items():
        # parent module path -> {suffix: QATLinear}
        by_parent: dict[str, dict[str, nn.Module]] = {}
        for name, module in named_qat:
            for suffix in suffixes:
                if name.endswith(suffix):
                    by_parent.setdefault(name.rsplit(".", 1)[0], {})[suffix] = module

        linked = 0
        for members in by_parent.values():
            if len(members) < 2:
                continue
            mods = list(members.values())
            for idx, mod in enumerate(mods):
                others = mods[:idx] + mods[idx + 1 :]
                mod._fusion_siblings_ref = [weakref.ref(o) for o in others]
            linked += 1
        counts[group_name] = linked

    logger.info(f"[QAT Fuse] Setup fusion siblings: {counts}")
    return counts
+
+
def enable_qat_fuse(model: nn.Module):
    """Turn on QAT fuse mode by wiring up fusion siblings on *model*.

    After this call, fused groups (QKV / gate-up) share weight scales during
    fake quantization.
    """
    setup_fusion_siblings(model)
    # Marker attribute so other components can detect that fuse mode is active.
    model._qat_fuse_enabled = True
    logger.info("[QAT Fuse] Enabled QAT fuse mode")
+
+
def invalidate_all_scales(model: nn.Module):
    """Drop every QATLinear's cached weight scales (call after optimizer.step()).

    Weights change each optimizer step, so any cached amax / scale values
    would otherwise go stale.
    """
    from verl.utils.qat.linear import QATLinear

    touched = 0
    for layer in model.modules():
        if not isinstance(layer, QATLinear):
            continue
        layer._weight_blockwise_scale = None
        layer._weight_global_scale = None
        layer._cached_weight_amax = None
        touched += 1

    logger.debug(f"[QAT Fuse] Invalidated scales for {touched} QATLinear layers")
diff --git a/verl/utils/qat/linear.py b/verl/utils/qat/linear.py
new file mode 100644
index 00000000000..4b6c6bc8f41
--- /dev/null
+++ b/verl/utils/qat/linear.py
@@ -0,0 +1,385 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""QAT FakeQuantized Linear module for NVFP4 (W4A4/W4A16) with FSDP compatibility.
+
+Includes Triton kernels for high-performance FP4 quantization.
+"""
+
+from enum import Enum
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["QATLinear", "QATMode"]
+
+
+import triton
+import triton.language as tl
+
# Map torch dtypes to their Triton counterparts for the kernel's OUT_DTYPE.
_TORCH_TO_TL_DTYPE = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}
# Largest representable magnitudes of the FP4 E2M1 and FP8 E4M3 formats.
FP4_E2M1_MAX: float = 6.0
FP8_E4M3_MAX: float = 448.0
+
+
# Triton kernel: NVFP4 (E2M1) fake quantization of an (M, N) tensor, one
# (TILE_M, TILE_N) tile per program instance. Each run of BLOCK_SIZE values
# shares a blockwise scale that is rounded through FP8-E4M3, mirroring how
# the true quantized scale would be stored.
@triton.jit
def _fp4_fake_quant_kernel(
    x_ptr,
    y_ptr,
    M,
    N,
    global_scale_ptr,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    BLOCK_SIZE: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    NUM_FP4_BLOCKS: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    FP4_MAX: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    # Tile coordinates of this program instance.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    row_start = pid_m * TILE_M
    col_start = pid_n * TILE_N

    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(M, N),
        strides=(stride_xm, stride_xn),
        offsets=(row_start, col_start),
        block_shape=(TILE_M, TILE_N),
        order=(1, 0),
    )
    y_block_ptr = tl.make_block_ptr(
        base=y_ptr,
        shape=(M, N),
        strides=(stride_ym, stride_yn),
        offsets=(row_start, col_start),
        block_shape=(TILE_M, TILE_N),
        order=(1, 0),
    )

    # Guard against a zero global scale so the division below cannot blow up.
    global_scale = tl.load(global_scale_ptr).to(tl.float32)
    global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)

    # Load the tile (zero-padded at boundaries) and view it as FP4 blocks.
    tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
    x_abs = tl.abs(tile_reshaped)

    # Blockwise scale: per-block amax, mapped into FP8-E4M3 range and rounded
    # through FP8 to reproduce the precision loss of the stored scale.
    block_max = tl.max(x_abs, axis=2, keep_dims=True)
    block_max_scaled = block_max / (FP4_MAX * global_scale_safe)
    block_max_scaled = tl.minimum(block_max_scaled, FP8_MAX)
    block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
    # All-zero (or tiny) blocks get scale 1.0 so the division below is a no-op.
    block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)

    block_max_quant_broadcast = tl.broadcast_to(block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
    abs_scaled = x_abs / block_max_quant_broadcast

    # Round onto the non-negative FP4 E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6};
    # thresholds are the midpoints, with ties resolved toward even mantissas
    # (e.g. 0.25 -> 0, 0.75 -> 1.0, 2.5 -> 2, 3.5 -> 4).
    q_val = tl.where(
        abs_scaled <= 0.25,
        0.0,
        tl.where(
            abs_scaled < 0.75,
            0.5,
            tl.where(
                abs_scaled <= 1.25,
                1.0,
                tl.where(
                    abs_scaled < 1.75,
                    1.5,
                    tl.where(
                        abs_scaled <= 2.5,
                        2.0,
                        tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, FP4_MAX)),
                    ),
                ),
            ),
        ),
    )

    # Undo the scaling, restore the sign, and store the fake-quantized tile.
    x_rescaled = q_val * block_max_quant_broadcast
    x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)
    tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N))

    tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1))
+
+
def fp4_fake_quant_weight(
    weight: torch.Tensor,
    global_amax: torch.Tensor = None,
    block_size: int = 16,
    tile_rows: int = 16,
    tile_cols: int = 64,
) -> torch.Tensor:
    """Fake-quantize *weight* to NVFP4 via the Triton kernel, preserving shape and dtype.

    Args:
        weight: Tensor to quantize; flattened to 2-D over its last dimension.
        global_amax: Optional precomputed global abs-max (computed here if None).
        block_size: Number of elements sharing one FP4 blockwise scale.
        tile_rows: Kernel tile height.
        tile_cols: Kernel tile width (rounded up to a multiple of block_size).
    """
    original_shape, original_dtype = weight.shape, weight.dtype
    flat = weight.reshape(-1, original_shape[-1]).contiguous()
    rows, cols = flat.shape
    out = torch.empty_like(flat)

    # Round the column tile up to a whole number of FP4 blocks.
    effective_tile_cols = max(tile_cols, block_size)
    tile_cols_aligned = -(-effective_tile_cols // block_size) * block_size
    num_fp4_blocks = tile_cols_aligned // block_size

    if global_amax is None:
        global_amax = weight.abs().max().to(torch.float32)
    # NVFP4 global scale: amax spread over the FP4 x FP8 dynamic range.
    global_scale = global_amax.float() / (FP4_E2M1_MAX * FP8_E4M3_MAX)

    launch_grid = (triton.cdiv(rows, tile_rows), triton.cdiv(cols, tile_cols_aligned))

    _fp4_fake_quant_kernel[launch_grid](
        flat,
        out,
        rows,
        cols,
        global_scale,
        flat.stride(0),
        flat.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_SIZE=block_size,
        TILE_M=tile_rows,
        TILE_N=tile_cols_aligned,
        NUM_FP4_BLOCKS=num_fp4_blocks,
        OUT_DTYPE=_TORCH_TO_TL_DTYPE[original_dtype],
        FP4_MAX=FP4_E2M1_MAX,
        FP8_MAX=FP8_E4M3_MAX,
    )
    return out.view(*original_shape)
+
+
class STEFP4QuantTriton(torch.autograd.Function):
    """Straight-Through Estimator wrapper for Triton FP4 quantization kernel.

    Forward fake-quantizes *x*; backward passes the incoming gradient through
    unchanged (identity STE), with None for the non-differentiable arguments.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, global_amax: torch.Tensor, block_size: int) -> torch.Tensor:
        return fp4_fake_quant_weight(x, global_amax=global_amax, block_size=block_size)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        # One entry per forward input: identity grad for x, None for the rest.
        return grad_output, None, None
+
+
class QATMode(str, Enum):
    """QAT quantization mode (str-valued, so members compare equal to plain strings)."""

    W4A4 = "w4a4"  # Weight 4-bit, Activation 4-bit (dynamic)
    W4A16 = "w4a16"  # Weight 4-bit, Activation 16-bit (weight only)
+
+
class QATLinear(nn.Linear):
    """QAT FakeQuantized Linear layer with FSDP compatibility.

    Drop-in nn.Linear replacement that fake-quantizes its weight (and, in
    W4A4 mode, its input activations) to NVFP4 on every forward pass. Weight
    scale caches are plain attributes and must be cleared after each
    optimizer step (see invalidate_all_scales in core.py) or they go stale.
    """

    # Sentinel stored in the scale/amax buffers before any value is observed.
    _UNINITIALIZED_SCALE = -1.0

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        mode: QATMode = QATMode.W4A4,
        group_size: int = 16,
        activation_observer: str = "static_minmax",  # Observer strategy for activation global_scale
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)

        self.mode = mode
        self.group_size = group_size
        self.activation_observer = activation_observer

        # Per-rank caches, kept as plain attributes (not buffers) — presumably
        # so state_dict/FSDP never manage them; cleared by invalidate_all_scales.
        self._weight_blockwise_scale: Optional[torch.Tensor] = None
        self._weight_global_scale: Optional[torch.Tensor] = None
        self._cached_weight_amax: Optional[torch.Tensor] = None
        # Weak refs to fused siblings (QKV / gate-up); set by setup_fusion_siblings.
        self._fusion_siblings_ref = None

        if mode == QATMode.W4A4:
            # Persistent buffer so the activation scale survives checkpointing.
            self.register_buffer(
                "input_global_scale", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True
            )

        # NOTE(review): input_amax is registered in both modes even though only
        # the W4A4 observers update it — presumably for checkpoint
        # compatibility; confirm.
        self.register_buffer(
            "input_amax", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True
        )

        # Smoothing coefficient for the EMA ("minmax") observer.
        self._ema_decay: float = 0.01

        # When False, forward() falls back to a plain F.linear.
        self.fake_quant_enabled = True

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        mode: QATMode = QATMode.W4A4,
        group_size: int = 16,
        activation_observer: str = "static_minmax",
    ) -> "QATLinear":
        """Create QATLinear from an existing nn.Linear, cloning its parameters."""
        has_bias = linear.bias is not None

        new_linear = cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=has_bias,
            mode=mode,
            group_size=group_size,
            activation_observer=activation_observer,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        # Meta-device modules carry no real data; their weights load later.
        if linear.weight.device != torch.device("meta"):
            new_linear.weight = nn.Parameter(linear.weight.clone())
            if has_bias:
                new_linear.bias = nn.Parameter(linear.bias.clone())

        return new_linear

    def _is_amax_initialized(self) -> bool:
        """Check if input_amax has been initialized (differs from the sentinel)."""
        if not hasattr(self, "input_amax"):
            return False
        # .item() forces a device sync; acceptable for a 1-element buffer.
        return self.input_amax.item() != self._UNINITIALIZED_SCALE

    def _update_input_global_scale(self, x: torch.Tensor):
        """Update static input_global_scale based on observer strategy.

        Strategies: "memoryless_minmax" uses only the current batch,
        "static_minmax" keeps the running max, "minmax" keeps an EMA.
        """
        assert self.mode == QATMode.W4A4, "_update_input_global_scale should only be called in W4A4 mode"

        current_amax = torch.amax(torch.abs(x)).detach().to(torch.float32)

        # Keep the activation statistics consistent across ranks.
        if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
            torch.distributed.all_reduce(current_amax, op=torch.distributed.ReduceOp.MAX)

        # Full NVFP4 dynamic range: FP8 blockwise scale x FP4 value range.
        scale_factor = FP8_E4M3_MAX * FP4_E2M1_MAX

        if self.activation_observer == "memoryless_minmax":
            # Scale derived from this batch only; no history kept.
            new_scale = (scale_factor / (current_amax + 1e-12)).view(1)
            self.input_global_scale.copy_(new_scale.to(self.input_global_scale.device))

        elif self.activation_observer == "static_minmax":
            # Running maximum over all observed batches.
            if not self._is_amax_initialized():
                self.input_amax.copy_(current_amax.view(1).to(self.input_amax.device))
            else:
                new_amax = torch.maximum(self.input_amax, current_amax.view(1).to(self.input_amax.device))
                self.input_amax.copy_(new_amax)
            amax_f32 = self.input_amax.to(torch.float32)
            new_scale = (scale_factor / (amax_f32 + 1e-12)).float().view(1)
            self.input_global_scale.copy_(new_scale.to(self.input_global_scale.device))

        elif self.activation_observer == "minmax":
            # Exponential moving average of the abs-max.
            if not self._is_amax_initialized():
                self.input_amax.copy_(current_amax.view(1).to(self.input_amax.device))
            else:
                new_amax = (1 - self._ema_decay) * self.input_amax + self._ema_decay * current_amax.view(1).to(
                    self.input_amax.device
                )
                self.input_amax.copy_(new_amax)
            amax_f32 = self.input_amax.to(torch.float32)
            new_scale = (scale_factor / (amax_f32 + 1e-12)).float().view(1)
            self.input_global_scale.copy_(new_scale.to(self.input_global_scale.device))

        else:
            raise ValueError(f"Unknown activation_observer: {self.activation_observer}")

    def _fake_quantize_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Apply fake quantization to weight tensor using Triton kernel.

        The global amax is cached between forwards (and shared with fusion
        siblings when present) and cleared by invalidate_all_scales().
        """
        with torch.no_grad():
            if self._cached_weight_amax is not None:
                global_amax = self._cached_weight_amax
            else:
                siblings_ref = getattr(self, "_fusion_siblings_ref", None)

                if siblings_ref is not None:
                    # Resolve live siblings; skip collected or meta-device ones.
                    siblings = [ref() for ref in siblings_ref if ref() is not None]
                    siblings = [s for s in siblings if s.weight.device != torch.device("meta")]

                    # Reuse a sibling's cached amax if any sibling has one...
                    for sibling in siblings:
                        sibling_amax = getattr(sibling, "_cached_weight_amax", None)
                        if sibling_amax is not None:
                            global_amax = sibling_amax
                            self._cached_weight_amax = global_amax
                            break
                    else:
                        # ...otherwise compute the fused amax over the whole
                        # group and share it with every sibling.
                        all_modules = [self] + siblings
                        amaxes = [m.weight.abs().max().to(torch.float32) for m in all_modules]
                        global_amax = torch.max(torch.stack(amaxes))

                        self._cached_weight_amax = global_amax
                        for sibling in siblings:
                            sibling._cached_weight_amax = global_amax
                else:
                    global_amax = weight.abs().max().to(torch.float32)
                    self._cached_weight_amax = global_amax

            if self._weight_global_scale is None:
                self._weight_global_scale = global_amax.float() / (FP4_E2M1_MAX * FP8_E4M3_MAX)

        # Outside no_grad so the STE backward can propagate into the weight.
        result = STEFP4QuantTriton.apply(weight, global_amax, self.group_size)

        return result

    def _fake_quantize_activation(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fake quantization to activation tensor (W4A4 mode only)."""
        original_shape = x.shape

        # The quantizer works on 2-D inputs; collapse (batch, seq) dims.
        if x.dim() == 3:
            x_2d = x.view(-1, x.shape[-1])
        else:
            x_2d = x

        # Observers only run during training; eval uses the stored scale.
        if self.training:
            self._update_input_global_scale(x_2d)

        if self.input_global_scale.item() == self._UNINITIALIZED_SCALE:
            raise RuntimeError("W4A4 input_global_scale uninitialized. Load PTQ model first.")

        # Recover the amax the kernel expects from the stored global scale.
        global_amax = (FP4_E2M1_MAX * FP8_E4M3_MAX) / self.input_global_scale.to(x.device)
        result = STEFP4QuantTriton.apply(x_2d, global_amax, self.group_size)
        return result.view(original_shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with fake quantization."""
        if not self.fake_quant_enabled:
            return F.linear(x, self.weight, self.bias)

        weight_fq = self._fake_quantize_weight(self.weight)

        if self.mode == QATMode.W4A4:
            x_fq = self._fake_quantize_activation(x)
        else:
            x_fq = x

        return F.linear(x_fq, weight_fq, self.bias)

    def extra_repr(self) -> str:
        """Extend the default repr with the QAT-specific configuration."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, mode={self.mode.value}, "
            f"group_size={self.group_size}, fake_quant_enabled={self.fake_quant_enabled}"
        )
diff --git a/verl/utils/qat/quantizer.py b/verl/utils/qat/quantizer.py
new file mode 100644
index 00000000000..5ed96d46a3e
--- /dev/null
+++ b/verl/utils/qat/quantizer.py
@@ -0,0 +1,308 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Fast NVFP4 Quantizer for verl FSDP training.
+
+Directly computes scales and quantizes weights using compressed_tensors APIs.
+Includes scale computation utilities for weight quantization.
+"""
+
+import logging
+import os
+import re
+from typing import Generator, Iterable, Optional
+
+import torch
+from compressed_tensors.compressors.quantized_compressors.fp4_quantized import NVFP4PackedCompressor
+from compressed_tensors.quantization.quant_args import (
+ FP4_E2M1_DATA,
+ FP8_E4M3_DATA,
+ QuantizationArgs,
+ QuantizationStrategy,
+ QuantizationType,
+)
+from compressed_tensors.quantization.utils.helpers import generate_gparam
+
+from verl.utils.device import get_device_name, get_torch_device
+
+logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+_LAYER_IDX_RE = re.compile(r"layers\.(\d+)\.")
+
+
def compute_blockwise_scale(
    weight: torch.Tensor,
    global_scale: torch.Tensor,
    group_size: int = 16,
) -> torch.Tensor:
    """Compute per-block FP8-E4M3 weight scales from a pre-computed global scale.

    The weight is split along in_features into groups of *group_size*; each
    group's abs-max determines its scale. Passing a shared *global_scale*
    lets fused layers (QKV / gate-up) agree on one scaling.

    Returns:
        A (out_features, in_features // group_size) float8_e4m3fn tensor,
        with exact zeros replaced by a tiny epsilon.
    """
    rows, cols = weight.shape
    grouped = weight.view(rows, cols // group_size, group_size)
    group_amax = torch.abs(grouped).amax(dim=-1).to(torch.float32)

    # Scale in float32 first, clamped into the representable FP8 range.
    scale_f32 = (global_scale * (group_amax / FP4_E2M1_DATA.max)).clamp(
        min=-FP8_E4M3_DATA.max, max=FP8_E4M3_DATA.max
    )
    scale_fp8 = scale_f32.to(torch.float8_e4m3fn)

    # A zero scale would wipe out its whole block on dequant; substitute eps.
    eps = torch.finfo(torch.float8_e4m3fn).eps
    eps_tensor = torch.tensor(eps, dtype=scale_fp8.dtype, device=weight.device)
    return torch.where(scale_fp8 == 0, eps_tensor, scale_fp8)
+
+
# Fusion patterns for transformer models: suffix groups whose members, when
# all present under one parent module, receive a single shared global scale
# (see fuse_global_scales).
FUSE_PATTERNS = {
    "qkv": ["q_proj", "k_proj", "v_proj"],
    "gate_up": ["gate_proj", "up_proj"],
}
+
+
def fuse_global_scales(
    layer_global_scales: dict[str, torch.Tensor],
    strategy: str = "min",
) -> dict[str, torch.Tensor]:
    """Share one global scale across each complete QKV / GateUp fusion group.

    Layers whose parent module contains a full fusion pattern all receive the
    group's fused scale (currently the group minimum); every other layer
    keeps its own scale.
    """
    if not layer_global_scales:
        return {}

    # parent module path -> {child suffix: full layer name}
    children_by_parent: dict[str, dict[str, str]] = {}
    for full_name in layer_global_scales:
        parent, child = full_name.rsplit(".", 1) if "." in full_name else ("", full_name)
        children_by_parent.setdefault(parent, {})[child] = full_name

    fused_scales = {}
    fused_names = set()

    for children in children_by_parent.values():
        for patterns in FUSE_PATTERNS.values():
            member_names = [children[p] for p in patterns if p in children]
            if len(member_names) != len(patterns):
                continue  # incomplete group — members fall through unchanged
            if strategy != "min":
                raise ValueError(f"Unknown fuse strategy: {strategy}")
            member_scales = [layer_global_scales[n] for n in member_names]
            shared = torch.min(torch.cat(member_scales)).reshape([1])
            for member in member_names:
                fused_scales[member] = shared.clone()
                fused_names.add(member)

    # Pass through every layer not covered by a complete fusion group.
    for name, scale in layer_global_scales.items():
        if name not in fused_names:
            fused_scales[name] = scale

    return fused_scales
+
+
class QATQuantizer:
    """Quantizer for QAT-trained weights using compressed_tensors APIs.

    Streams parameters one decoder layer at a time: computes per-layer global
    scales, fuses them across QKV / gate-up groups, packs weights to NVFP4,
    and yields the resulting checkpoint tensors.
    """

    def __init__(
        self,
        mode: str = "w4a16",
        group_size: int = 16,
        ignore_patterns: Optional[list] = None,
        device: Optional[torch.device] = None,
        param_dtype: Optional[torch.dtype] = None,
    ):
        """Initialize the quantizer.

        Args:
            mode: "w4a16" (weight-only) or "w4a4" (also emits input_global_scale).
            group_size: Elements per NVFP4 blockwise scale.
            ignore_patterns: Module-name patterns to skip ("re:" prefix = regex).
            device: Device for scale/packing compute (defaults to the current
                accelerator).
            param_dtype: Optional dtype to cast weights to before quantizing.
        """
        self.mode = mode.lower()
        self._is_w4a4 = self.mode == "w4a4"  # W4A4 needs input_global_scale
        self.group_size = group_size
        self.ignore_patterns = ignore_patterns or ["lm_head", "embed_tokens", "re:.*mlp.gate$"]
        self.device = device or torch.device(get_device_name())
        self.param_dtype = param_dtype

        # compressed_tensors packer plus args describing the NVFP4 TENSOR_GROUP scheme.
        self._compressor = NVFP4PackedCompressor()
        self._quant_args = QuantizationArgs(
            num_bits=4,
            type=QuantizationType.FLOAT,
            symmetric=True,
            strategy=QuantizationStrategy.TENSOR_GROUP,
            group_size=group_size,
            scale_dtype=FP8_E4M3_DATA.dtype,
        )

    def _should_quantize(self, name: str, tensor: torch.Tensor) -> bool:
        """Check if parameter should be quantized.

        Only 2-D ``.weight`` tensors whose second dim is divisible by the
        group size and that match no ignore pattern qualify.
        """
        if not name.endswith(".weight"):
            return False
        if tensor.dim() != 2:
            return False
        if tensor.shape[1] % self.group_size != 0:
            return False

        module_name = name.rsplit(".weight", 1)[0]

        for pattern in self.ignore_patterns:
            if pattern.startswith("re:"):
                # Regex pattern - use re.match like vLLM does
                regex = pattern[3:]
                if re.match(regex, module_name):
                    return False
            else:
                if pattern in module_name:
                    return False
        return True

    @staticmethod
    def _extract_layer_idx(name: str) -> Optional[int]:
        """Extract decoder layer index from parameter name (None if absent)."""
        match = _LAYER_IDX_RE.search(name)
        return int(match.group(1)) if match else None

    def _process_layer_group(
        self,
        layer_idx: Optional[int],
        layer_params: dict[str, torch.Tensor],
        input_global_scales: dict[str, torch.Tensor],
        output_device: torch.device,
    ) -> list[tuple[str, torch.Tensor]]:
        """Quantize one decoder layer's buffered params. Returns list of (name, tensor)."""
        layer_weights = {}
        layer_passthrough = {}

        # Split this layer's params into quantizable weights vs pass-through.
        for name, tensor in layer_params.items():
            # QAT bookkeeping buffers are handled separately, never emitted raw.
            if "input_global_scale" in name or "input_amax" in name:
                continue

            if self._should_quantize(name, tensor):
                layer_name = name.rsplit(".weight", 1)[0]
                layer_weights[layer_name] = (name, tensor)
            else:
                layer_passthrough[name] = tensor

        if layer_idx is None and layer_weights:
            raise RuntimeError(
                f"[QAT Quantizer] Unexpected quantizable weights outside decoder layers: "
                f"{list(layer_weights.keys())}. These should be in ignore_patterns."
            )

        if not layer_weights:
            return [(name, tensor.to(output_device)) for name, tensor in layer_passthrough.items()]

        # Move weights to GPU, compute global scales
        weights_on_gpu = {}
        layer_global_scales = {}

        for layer_name, (_, tensor) in layer_weights.items():
            weight_gpu = tensor.to(device=self.device, dtype=self.param_dtype)
            weights_on_gpu[layer_name] = weight_gpu
            amax = torch.amax(torch.abs(weight_gpu)).to(torch.float32)
            # Symmetric range [-amax, amax] -> FP4xFP8 global scale parameter.
            layer_global_scales[layer_name] = generate_gparam(
                -amax.unsqueeze(0),
                amax.unsqueeze(0),
                scale_data=FP8_E4M3_DATA,
                quant_data=FP4_E2M1_DATA,
                dtype=torch.float32,
            )

        # Fused layers (QKV / gate-up) must share one global scale.
        fused_global_scales = fuse_global_scales(layer_global_scales, strategy="min")

        results = []

        for layer_name, weight_gpu in weights_on_gpu.items():
            fused_global_scale = fused_global_scales[layer_name]
            weight_scale = compute_blockwise_scale(weight_gpu, fused_global_scale, self.group_size)
            weight_packed = self._compressor.compress_weight(
                weight=weight_gpu,
                scale=weight_scale.float(),
                global_scale=fused_global_scale,
                quantization_args=self._quant_args,
            )["weight_packed"]

            results.append((f"{layer_name}.weight_packed", weight_packed.to(output_device)))
            results.append((f"{layer_name}.weight_scale", weight_scale.to(output_device)))
            results.append((f"{layer_name}.weight_global_scale", fused_global_scale.to(output_device)))

            if self._is_w4a4:
                if layer_name in input_global_scales:
                    results.append(
                        (
                            f"{layer_name}.input_global_scale",
                            input_global_scales[layer_name].float().to(output_device),
                        )
                    )
                else:
                    raise ValueError(
                        f"W4A4 mode requires input_global_scale for layer '{layer_name}', "
                        f"but it's not found or uninitialized (-1.0)."
                    )

        # Drop GPU copies before the next layer is processed.
        del weights_on_gpu, layer_global_scales, fused_global_scales

        for name, tensor in layer_passthrough.items():
            results.append((name, tensor.to(output_device)))

        return results

    def quantize_with_fusion(
        self,
        params: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
        target_device: Optional[torch.device] = None,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Streaming quantize: consume input layer by layer, yield (name, tensor) pairs."""
        if isinstance(params, dict):
            params = params.items()

        output_device = target_device or torch.device("cpu")

        # Sentinel distinguishes "no layer seen yet" from layer_idx None.
        _sentinel = object()
        current_layer_idx = _sentinel
        layer_buffer: dict[str, torch.Tensor] = {}
        input_global_scales: dict[str, torch.Tensor] = {}
        for name, tensor in params:
            tensor_cpu = tensor.to("cpu") if tensor.is_cuda else tensor
            layer_idx = self._extract_layer_idx(name)

            # Collect input_global_scales for W4A4 as we go
            if self._is_w4a4 and "input_global_scale" in name:
                scale_layer_name = name.replace(".input_global_scale", "")
                # -1.0 is the QAT "never observed" sentinel value.
                if tensor_cpu.numel() == 1 and tensor_cpu.item() == -1.0:
                    logger.warning(f"W4A4: {scale_layer_name} input_global_scale is uninitialized")
                else:
                    input_global_scales[scale_layer_name] = tensor_cpu

            # Layer boundary: flush previous layer
            if layer_idx != current_layer_idx and current_layer_idx is not _sentinel and layer_buffer:
                yield from self._process_layer_group(
                    current_layer_idx, layer_buffer, input_global_scales, output_device
                )
                layer_buffer = {}

            current_layer_idx = layer_idx
            layer_buffer[name] = tensor_cpu

        # Flush last buffered layer
        if layer_buffer:
            yield from self._process_layer_group(current_layer_idx, layer_buffer, input_global_scales, output_device)

        get_torch_device().empty_cache()
+
+
+__all__ = [
+ "QATQuantizer",
+]
diff --git a/verl/utils/qat/vllm_patch.py b/verl/utils/qat/vllm_patch.py
new file mode 100644
index 00000000000..77d2ec17eeb
--- /dev/null
+++ b/verl/utils/qat/vllm_patch.py
@@ -0,0 +1,828 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+vLLM NVFP4 Patches for Dynamic Weight Updates.
+
+Enables dynamic weight reloading for NVFP4 quantized models in vLLM.
+
+Supported schemes:
+- Dense: W4A16-FP4, W4A4-FP4
+- MoE: NVFP4-MoE
+"""
+
+import logging
+import os
+from typing import Optional
+from unittest.mock import patch
+
+import torch
+from torch.nn import Parameter
+
+from verl.utils.device import get_device_name
+
+logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+class ParamMetaDict(dict):
+    """
+    Dict-like class for parameter management with metadata-based rebuild and tensor swap.
+
+    Supports:
+    - Rebuild of deleted parameters from saved metadata (``_hf_param_meta`` left
+      on each quantized layer by the patched ``process_weights_after_loading``)
+    - Tensor swap for parameters whose shape changes between HF and Marlin
+      format (keeps the Marlin tensor's storage address stable for CUDA Graph)
+    """
+
+    def __init__(self, model: torch.nn.Module, device: Optional[torch.device] = None):
+        """
+        Initialize ParamMetaDict from a model.
+
+        Args:
+            model: vLLM model (may be wrapped, e.g. by a ModelRunner exposing ``.model``)
+            device: Device for created parameters
+        """
+        super().__init__()
+        self.device = device
+
+        # Get the actual model (handle vLLM's wrapper structure)
+        actual_model = model
+        if hasattr(model, "model"):
+            actual_model = model.model
+        self._model = actual_model
+
+        # Build mappings by scanning all modules
+        self._layer_meta_cache: dict[str, dict] = {}  # layer name -> {"module", "meta"}
+        self._tensor_swap_layers: dict[str, dict] = {}  # layers whose scales need swap
+
+        self._build_mappings()
+
+        # Seed the dict with the parameters currently registered on the model.
+        for name, param in actual_model.named_parameters():
+            self[name] = param
+
+    def _build_mappings(self):
+        """Build layer metadata cache for rebuild and tensor swap."""
+        for layer_name, module in self._model.named_modules():
+            # _hf_param_meta is only present on layers that went through the
+            # patched process_weights_after_loading (i.e. quantized layers).
+            if hasattr(module, "_hf_param_meta"):
+                self._layer_meta_cache[layer_name] = {
+                    "module": module,
+                    "meta": module._hf_param_meta,
+                }
+
+                # Dense layers: weight_scale changes shape HF -> Marlin, so it
+                # needs the swap treatment if a Marlin-side ref was recorded.
+                if "weight_scale" in module._hf_param_meta:
+                    marlin_refs = getattr(module, "_marlin_tensor_refs", {})
+                    if "weight_scale" in marlin_refs:
+                        self._tensor_swap_layers[layer_name] = {
+                            "module": module,
+                            "marlin_ref": marlin_refs["weight_scale"],
+                            "hf_meta": module._hf_param_meta["weight_scale"],
+                        }
+
+                # MoE layers (w13_weight_scale, w2_weight_scale)
+                if "w13_weight_scale" in module._hf_param_meta:
+                    marlin_refs = getattr(module, "_marlin_tensor_refs", {})
+                    if "w13_weight_scale" in marlin_refs:
+                        self._tensor_swap_layers[f"{layer_name}.w13"] = {
+                            "module": module,
+                            "param_name": "w13_weight_scale",
+                            "marlin_ref": marlin_refs["w13_weight_scale"],
+                            "hf_meta": module._hf_param_meta["w13_weight_scale"],
+                        }
+                    if "w2_weight_scale" in marlin_refs:
+                        self._tensor_swap_layers[f"{layer_name}.w2"] = {
+                            "module": module,
+                            "param_name": "w2_weight_scale",
+                            "marlin_ref": marlin_refs["w2_weight_scale"],
+                            "hf_meta": module._hf_param_meta["w2_weight_scale"],
+                        }
+
+    def _try_rebuild(self, key: str) -> Optional[Parameter]:
+        """
+        Try to rebuild a parameter from metadata if it was deleted.
+
+        Args:
+            key: Full parameter name ("<layer>.<param>")
+
+        Returns:
+            Rebuilt parameter (also registered on the module) or None if it
+            cannot be rebuilt from the cached metadata.
+        """
+        # Extract layer name and param name
+        parts = key.rsplit(".", 1)
+        if len(parts) != 2:
+            return None
+
+        layer_name, param_name = parts
+
+        # Check if we have metadata for this layer
+        if layer_name not in self._layer_meta_cache:
+            return None
+
+        cache_entry = self._layer_meta_cache[layer_name]
+        module = cache_entry["module"]
+        meta = cache_entry["meta"]
+
+        # Check if this param needs rebuild
+        if param_name not in meta:
+            return None
+
+        # Already exists on module? Reuse it rather than allocating a new one.
+        if hasattr(module, param_name):
+            param = getattr(module, param_name)
+            if param is not None:
+                return param
+
+        # Rebuild from metadata
+        new_param = _create_param_from_meta(module, param_name, meta[param_name], self.device)
+        module.register_parameter(param_name, new_param)
+        return new_param
+
+    def prepare_for_reload(self) -> None:
+        """Replace Marlin-format tensors with HF-shape tensors for reload.
+
+        After this, load_weights can copy HF-format checkpoints in; the Marlin
+        tensors survive via the refs stored in ``_marlin_tensor_refs``.
+        """
+        for layer_name, swap_info in self._tensor_swap_layers.items():
+            module = swap_info["module"]
+            # Dense entries carry no "param_name"; default to "weight_scale".
+            param_name = swap_info.get("param_name", "weight_scale")
+            hf_meta = swap_info["hf_meta"]
+            if hasattr(module, param_name):
+                new_param = _create_param_from_meta(module, param_name, hf_meta, self.device)
+                setattr(module, param_name, new_param)
+
+    def __getitem__(self, key: str) -> Parameter:
+        """Get parameter, rebuilding from metadata on a miss."""
+        # Try standard lookup first
+        if key in dict.keys(self):
+            return super().__getitem__(key)
+
+        # Try rebuild from metadata; cache the result for next lookup.
+        param = self._try_rebuild(key)
+        if param is not None:
+            self[key] = param
+            return param
+
+        raise KeyError(f"Parameter not found: {key}")
+
+    def __contains__(self, key: str) -> bool:
+        """Check if parameter exists (counting params rebuildable from metadata)."""
+        if super().__contains__(key):
+            return True
+
+        # Check if can rebuild from metadata (without actually rebuilding).
+        parts = key.rsplit(".", 1)
+        if len(parts) == 2:
+            layer_name, param_name = parts
+            if layer_name in self._layer_meta_cache:
+                meta = self._layer_meta_cache[layer_name]["meta"]
+                if param_name in meta:
+                    return True
+
+        return False
+
+    def get(self, key: str, default=None):
+        """Get parameter with default (routes through rebuilding __getitem__)."""
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+
+def _create_param_from_meta(
+    module: torch.nn.Module,
+    param_name: str,
+    meta: dict,
+    device: Optional[torch.device] = None,
+) -> Parameter:
+    """Create a Parameter from saved metadata. Used by rebuild and tensor swap.
+
+    Args:
+        module: Layer the parameter belongs to (consulted for saved weight loaders).
+        param_name: Attribute name of the parameter on the layer.
+        meta: Metadata dict produced by ``save_param_meta`` ("shape", "dtype",
+            "device", "param_class", optionally "input_dim"/"output_dim"/"quant_method").
+        device: Overrides the device recorded in ``meta`` when given.
+
+    Returns:
+        A new, uninitialized (torch.empty) Parameter ready to be filled by load_weights.
+    """
+    shape = meta["shape"]
+    dtype = meta["dtype"]
+    dev = device or meta.get("device", get_device_name())
+    param_class = meta.get("param_class", Parameter)
+
+    # Weight loaders were stashed on the module because vLLM param subclasses
+    # need them at construction time.
+    weight_loaders = getattr(module, "_weight_loaders", {})
+    weight_loader = weight_loaders.get(param_name)
+
+    data = torch.empty(shape, dtype=dtype, device=dev)
+
+    try:
+        if param_class is not Parameter and weight_loader is not None:
+            # Reconstruct the original vLLM parameter subclass with the
+            # dims it needs for sharded loading.
+            kwargs = {"data": data, "weight_loader": weight_loader}
+            if "input_dim" in meta:
+                kwargs["input_dim"] = meta["input_dim"]
+            if "output_dim" in meta:
+                kwargs["output_dim"] = meta["output_dim"]
+            new_param = param_class(**kwargs)
+        else:
+            new_param = Parameter(data, requires_grad=False)
+            if weight_loader is not None:
+                new_param.weight_loader = weight_loader
+    except Exception as e:
+        # Best-effort fallback: a plain Parameter with the loader attached is
+        # still loadable even if the subclass constructor signature changed.
+        logger.warning(f"Failed to create param {param_name} with class {param_class}: {e}, using Parameter")
+        new_param = Parameter(data, requires_grad=False)
+        if weight_loader is not None:
+            new_param.weight_loader = weight_loader
+
+    # MoE weight loaders read quant_method off the parameter; restore it.
+    if "quant_method" in meta:
+        new_param.quant_method = meta["quant_method"]
+
+    return new_param
+
+
+def save_param_meta(layer: torch.nn.Module, param_name: str):
+    """Save parameter metadata (shape/dtype/device/class) for later rebuild.
+
+    Stores into ``layer._hf_param_meta[param_name]``; no-op when the parameter
+    is absent. Called on the first process_weights pass, before HF-format
+    parameters are deleted or overwritten.
+    """
+    if not hasattr(layer, "_hf_param_meta"):
+        layer._hf_param_meta = {}
+
+    param = getattr(layer, param_name, None)
+    if param is None:
+        return
+
+    meta = {
+        "shape": tuple(param.shape),
+        "dtype": param.dtype,
+        "device": str(param.device),
+        "param_class": type(param),  # Save the actual parameter class
+    }
+
+    # Save vLLM-specific attributes needed for reconstruction
+    if hasattr(param, "_input_dim"):
+        meta["input_dim"] = param._input_dim
+    if hasattr(param, "_output_dim"):
+        meta["output_dim"] = param._output_dim
+
+    # Save MoE-specific attributes (quant_method is required by weight_loader)
+    if hasattr(param, "quant_method"):
+        meta["quant_method"] = param.quant_method
+
+    layer._hf_param_meta[param_name] = meta
+
+
+def _check_first_call(layer: torch.nn.Module) -> bool:
+    """Check if this is the first process_weights call, and increment counter.
+
+    The counter lives on the layer itself, so "first call" is per-layer:
+    True exactly once, then False for every subsequent reload.
+    """
+    count = getattr(layer, "_process_weights_call_count", 0)
+    layer._process_weights_call_count = count + 1
+    return count == 0
+
+
+# Dense W4A16 Patches
+def patched_w4a16_process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    """Patched process_weights_after_loading for W4A16 Dense layer.
+
+    Differences from stock vLLM: on the first call it additionally records
+    HF-parameter metadata and weight loaders (for later rebuild), and stashes
+    the Marlin weight_scale tensor in ``_marlin_tensor_refs``; on repeat calls
+    (weight reload) it copies new Marlin data into the existing tensors in
+    place, keeping storage addresses stable for CUDA Graph.
+    """
+    import vllm._custom_ops as ops
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+        marlin_make_workspace_new,
+        marlin_permute_scales,
+        nvfp4_marlin_process_global_scale,
+        nvfp4_marlin_process_scales,
+    )
+
+    is_first_call = _check_first_call(layer)
+
+    group_size = 16
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    device = layer.weight_packed.device
+    param_dtype = getattr(layer, "params_dtype", torch.float16)
+
+    # Save metadata (first call only) so deleted HF params can be rebuilt.
+    if is_first_call:
+        save_param_meta(layer, "weight_packed")
+        save_param_meta(layer, "weight_global_scale")
+        save_param_meta(layer, "weight_scale")
+        if not hasattr(layer, "_weight_loaders"):
+            layer._weight_loaders = {}
+        for pname in ["weight_packed", "weight_global_scale", "weight_scale"]:
+            param = getattr(layer, pname, None)
+            if param is not None and hasattr(param, "weight_loader"):
+                layer._weight_loaders[pname] = param.weight_loader
+
+    # Get HF format data
+    weight_packed_hf = layer.weight_packed.data
+    weight_global_scale_hf = layer.weight_global_scale.data
+    weight_scale_hf = layer.weight_scale.data
+
+    # Create workspace (first call only)
+    if is_first_call:
+        layer.workspace = marlin_make_workspace_new(device)
+
+    # Convert to Marlin format. Empty perm = no act-order permutation.
+    perm = torch.empty(0, dtype=torch.int, device=device)
+    qweight = weight_packed_hf.view(torch.int32).T.contiguous()
+    marlin_weight = ops.gptq_marlin_repack(
+        b_q_weight=qweight,
+        perm=perm,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=4,
+        is_a_8bit=False,
+    )
+
+    weight_scale = weight_scale_hf.T.contiguous().to(param_dtype)
+    weight_scale_permuted = marlin_permute_scales(
+        s=weight_scale,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        group_size=group_size,
+        is_a_8bit=False,
+    )
+    marlin_weight_scale = nvfp4_marlin_process_scales(weight_scale_permuted)
+
+    weight_scale_2_raw = (1.0 / weight_global_scale_hf.max()).to(param_dtype)
+    marlin_weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2_raw)
+
+    # Update compute parameters
+    if is_first_call:
+        layer.weight = Parameter(marlin_weight, requires_grad=False)
+        layer.weight_scale = Parameter(marlin_weight_scale, requires_grad=False)
+        layer.weight_scale_2 = Parameter(marlin_weight_scale_2, requires_grad=False)
+        if not hasattr(layer, "_marlin_tensor_refs"):
+            layer._marlin_tensor_refs = {}
+        # weight_scale changes shape HF->Marlin, so keep a ref to the Marlin
+        # tensor; reloads copy into it rather than reallocating.
+        layer._marlin_tensor_refs["weight_scale"] = layer.weight_scale.data
+    else:
+        # In-place copies keep tensor addresses stable (CUDA Graph safety).
+        layer.weight.data.copy_(marlin_weight)
+        layer.weight_scale_2.data.copy_(marlin_weight_scale_2)
+        marlin_scale_ref = layer._marlin_tensor_refs.get("weight_scale")
+        if marlin_scale_ref is not None:
+            marlin_scale_ref.copy_(marlin_weight_scale)
+            # Re-point weight_scale at the Marlin ref (it was swapped to an
+            # HF-shape tensor by prepare_for_reload).
+            layer.weight_scale = Parameter(marlin_scale_ref, requires_grad=False)
+        else:
+            logger.warning("W4A16: _marlin_tensor_refs['weight_scale'] not found")
+            layer.weight_scale = Parameter(marlin_weight_scale, requires_grad=False)
+
+    # Delete HF parameters; they are rebuilt from metadata before next reload.
+    if hasattr(layer, "weight_packed"):
+        delattr(layer, "weight_packed")
+    if hasattr(layer, "weight_global_scale"):
+        delattr(layer, "weight_global_scale")
+
+
+def patched_w4a4_process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    """Patched process_weights_after_loading for W4A4 Dense (all backends).
+
+    First call: records HF metadata/loaders and keeps refs to every processed
+    tensor. Repeat calls: copies the newly processed data into the existing
+    tensors in place, so addresses stay stable for CUDA Graph.
+    """
+    from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
+
+    is_first_call = _check_first_call(layer)
+
+    _W4A4_HF_PARAMS = ["weight_packed", "weight_scale", "weight_global_scale", "input_global_scale"]
+
+    if is_first_call:
+        for pname in _W4A4_HF_PARAMS:
+            save_param_meta(layer, pname)
+        if not hasattr(layer, "_weight_loaders"):
+            layer._weight_loaders = {}
+        for pname in _W4A4_HF_PARAMS:
+            param = getattr(layer, pname, None)
+            if param is not None and hasattr(param, "weight_loader"):
+                layer._weight_loaders[pname] = param.weight_loader
+
+    weight_packed_data = layer.weight_packed.data
+    weight_scale_data = layer.weight_scale.data
+    input_global_scale_data = layer.input_global_scale.data
+    weight_global_scale_data = layer.weight_global_scale.data
+
+    # Collapse per-shard scales to a single global scalar each.
+    global_input_scale = input_global_scale_data.max().to(torch.float32)
+    global_weight_scale = weight_global_scale_data.max().to(torch.float32)
+
+    # Backend-specific layout of weight / block-scale tensors.
+    if self.backend == "flashinfer-trtllm":
+        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
+
+        epilogue_tile_m = 128
+        processed_weight = shuffle_matrix_a(weight_packed_data.view(torch.uint8), epilogue_tile_m)
+        processed_weight_scale = (
+            shuffle_matrix_sf_a(weight_scale_data.view(torch.uint8), epilogue_tile_m)
+            .reshape(weight_scale_data.shape)
+            .view(torch.float8_e4m3fn)
+        )
+    elif self.backend == "fbgemm":
+        processed_weight_scale = swizzle_blockscale(weight_scale_data).view(-1).view(torch.uint8)
+        processed_weight = weight_packed_data
+    else:
+        # cutlass / flashinfer-cutlass
+        processed_weight_scale = swizzle_blockscale(weight_scale_data)
+        processed_weight = weight_packed_data
+
+    # Combined dequant factor used by the W4A4 GEMM epilogue.
+    alpha = 1.0 / (global_input_scale * global_weight_scale)
+
+    if is_first_call:
+        layer.weight_packed = Parameter(processed_weight, requires_grad=False)
+        layer.weight_scale = Parameter(processed_weight_scale, requires_grad=False)
+        layer.input_global_scale = Parameter(global_input_scale, requires_grad=False)
+        layer.weight_global_scale = Parameter(global_weight_scale, requires_grad=False)
+        layer.alpha = Parameter(alpha, requires_grad=False)
+
+        # Keep refs to every processed tensor so reloads can copy in place.
+        if not hasattr(layer, "_marlin_tensor_refs"):
+            layer._marlin_tensor_refs = {}
+        layer._marlin_tensor_refs["weight_packed"] = layer.weight_packed.data
+        layer._marlin_tensor_refs["weight_scale"] = layer.weight_scale.data
+        layer._marlin_tensor_refs["input_global_scale"] = layer.input_global_scale.data
+        layer._marlin_tensor_refs["weight_global_scale"] = layer.weight_global_scale.data
+        layer._marlin_tensor_refs["alpha"] = layer.alpha.data
+    else:
+        refs = layer._marlin_tensor_refs
+        for ref_name, new_data in [
+            ("weight_packed", processed_weight),
+            ("weight_scale", processed_weight_scale),
+            ("input_global_scale", global_input_scale),
+            ("weight_global_scale", global_weight_scale),
+            ("alpha", alpha),
+        ]:
+            ref = refs.get(ref_name)
+            if ref is not None:
+                ref.copy_(new_data)
+                # Re-point the attribute at the stable ref (prepare_for_reload
+                # may have swapped it to an HF-shape tensor).
+                setattr(layer, ref_name, Parameter(ref, requires_grad=False))
+            else:
+                logger.warning(f"W4A4: _marlin_tensor_refs['{ref_name}'] not found, creating new Parameter")
+                setattr(
+                    layer,
+                    ref_name,
+                    Parameter(
+                        new_data.clone() if isinstance(new_data, torch.Tensor) else torch.tensor(new_data),
+                        requires_grad=False,
+                    ),
+                )
+
+
+def _marlin_repack_experts(packed, perm, size_k, size_n, num_experts):
+    """Repack weight for each expert into Marlin format and stack.
+
+    Args:
+        packed: Per-expert packed weights; indexed as ``packed[i]`` per expert.
+        perm: Activation-order permutation tensor (empty = none).
+        size_k / size_n: GEMM reduction / output dims per expert.
+        num_experts: Number of experts to process.
+
+    Returns:
+        Stacked tensor of Marlin-format weights, one slice per expert.
+    """
+    import vllm._custom_ops as ops
+
+    result = []
+    for i in range(num_experts):
+        # Marlin repack expects int32-packed, K-major contiguous input.
+        qweight = packed[i].view(torch.int32).T.contiguous()
+        result.append(
+            ops.gptq_marlin_repack(
+                b_q_weight=qweight,
+                perm=perm,
+                size_k=size_k,
+                size_n=size_n,
+                num_bits=4,
+                is_a_8bit=False,
+            )
+        )
+    return torch.stack(result)
+
+
+def _marlin_process_scales_experts(scale_hf, param_dtype, size_k, size_n, group_size, num_experts):
+    """Process scales for each expert into Marlin format and stack.
+
+    Args:
+        scale_hf: HF-format per-expert block scales.
+        param_dtype: Compute dtype the scales are cast to.
+        size_k / size_n / group_size: GEMM dims and quantization group size.
+        num_experts: Number of experts to process.
+
+    Returns:
+        Stacked tensor of Marlin-format scales, one slice per expert.
+    """
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+        marlin_permute_scales,
+        nvfp4_marlin_process_scales,
+    )
+
+    result = []
+    # Cast once for all experts; per-expert slices are transposed below.
+    scales = scale_hf.to(param_dtype)
+    for i in range(num_experts):
+        s = marlin_permute_scales(
+            s=scales[i].T,
+            size_k=size_k,
+            size_n=size_n,
+            group_size=group_size,
+            is_a_8bit=False,
+        )
+        result.append(nvfp4_marlin_process_scales(s))
+    return torch.stack(result)
+
+
+def _process_nvfp4_moe_marlin(self, layer: torch.nn.Module, is_first_call: bool) -> None:
+    """Process MoE layer with MARLIN backend (W4A16).
+
+    First call creates the Marlin parameters and records refs to the
+    shape-changed scale tensors; repeat calls copy in place so tensor
+    addresses stay stable (CUDA Graph). Finishes by (re)building the fused
+    MoE kernel from the refreshed quant config.
+    """
+    from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import make_nvfp4_moe_kernel
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+        marlin_make_workspace_new,
+        nvfp4_marlin_process_global_scale,
+    )
+
+    group_size = 16
+    e = layer.num_experts
+    k = layer.hidden_size
+    n = layer.intermediate_size_per_partition
+    device = layer.w13_weight_packed.device
+    param_dtype = layer.params_dtype
+    # Gated activations (act_and_mul) fuse w1 and w3 into one tensor.
+    w13_num_shards = 2 if self.moe.is_act_and_mul else 1
+
+    if is_first_call:
+        layer.workspace = marlin_make_workspace_new(device, 4)
+
+    # Empty perm = no act-order permutation.
+    perm = torch.empty(0, dtype=torch.int, device=device)
+
+    # The kernel uses a single global scale per expert for w13, so w1 and w3
+    # must share one; warn (not fail) on mismatch.
+    if self.moe.is_act_and_mul and not torch.allclose(
+        layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
+    ):
+        logger.warning("w1_weight_global_scale must match w3_weight_global_scale. Accuracy may be affected.")
+
+    size_n_w13, size_k_w13 = n * w13_num_shards, k
+    size_n_w2, size_k_w2 = k, n
+
+    w13_weight_marlin = _marlin_repack_experts(layer.w13_weight_packed.data, perm, size_k_w13, size_n_w13, e)
+    w2_weight_marlin = _marlin_repack_experts(layer.w2_weight_packed.data, perm, size_k_w2, size_n_w2, e)
+    w13_weight_scale_marlin = _marlin_process_scales_experts(
+        layer.w13_weight_scale.data, param_dtype, size_k_w13, size_n_w13, group_size, e
+    )
+    w2_weight_scale_marlin = _marlin_process_scales_experts(
+        layer.w2_weight_scale.data, param_dtype, size_k_w2, size_n_w2, group_size, e
+    )
+
+    # Process global scales (stored as reciprocals of the checkpoint values).
+    w13_scale_2 = 1.0 / layer.w13_weight_global_scale[:, 0]
+    w2_scale_2 = 1.0 / layer.w2_weight_global_scale.data
+    w13_scale_2_processed = nvfp4_marlin_process_global_scale(w13_scale_2.to(param_dtype))
+    w2_scale_2_processed = nvfp4_marlin_process_global_scale(w2_scale_2.to(param_dtype))
+
+    # Update parameters
+    if is_first_call:
+        layer.w13_weight = Parameter(w13_weight_marlin, requires_grad=False)
+        layer.w2_weight = Parameter(w2_weight_marlin, requires_grad=False)
+        layer.w13_weight_scale = Parameter(w13_weight_scale_marlin, requires_grad=False)
+        layer.w2_weight_scale = Parameter(w2_weight_scale_marlin, requires_grad=False)
+        layer.w13_weight_scale_2 = Parameter(w13_scale_2_processed, requires_grad=False)
+        layer.w2_weight_scale_2 = Parameter(w2_scale_2_processed, requires_grad=False)
+        if not hasattr(layer, "_marlin_tensor_refs"):
+            layer._marlin_tensor_refs = {}
+        layer._marlin_tensor_refs["w13_weight_scale"] = layer.w13_weight_scale.data
+        layer._marlin_tensor_refs["w2_weight_scale"] = layer.w2_weight_scale.data
+    else:
+        layer.w13_weight.data.copy_(w13_weight_marlin)
+        layer.w2_weight.data.copy_(w2_weight_marlin)
+        layer.w13_weight_scale_2.data.copy_(w13_scale_2_processed)
+        layer.w2_weight_scale_2.data.copy_(w2_scale_2_processed)
+        w13_marlin_ref = layer._marlin_tensor_refs.get("w13_weight_scale")
+        w2_marlin_ref = layer._marlin_tensor_refs.get("w2_weight_scale")
+        if w13_marlin_ref is not None:
+            w13_marlin_ref.copy_(w13_weight_scale_marlin)
+            layer.w13_weight_scale = Parameter(w13_marlin_ref, requires_grad=False)
+        else:
+            # NOTE(review): in this fallback w13_weight_scale may still hold the
+            # HF-shape tensor from prepare_for_reload; copy_ would then see a
+            # shape mismatch — confirm this path is unreachable in practice.
+            logger.warning("MoE: _marlin_tensor_refs['w13_weight_scale'] not found")
+            layer.w13_weight_scale.data.copy_(w13_weight_scale_marlin)
+        if w2_marlin_ref is not None:
+            w2_marlin_ref.copy_(w2_weight_scale_marlin)
+            layer.w2_weight_scale = Parameter(w2_marlin_ref, requires_grad=False)
+        else:
+            logger.warning("MoE: _marlin_tensor_refs['w2_weight_scale'] not found")
+            layer.w2_weight_scale.data.copy_(w2_weight_scale_marlin)
+
+    # W4A16: activations are not quantized, so no input scales.
+    layer.w13_input_scale = None
+    layer.w2_input_scale = None
+
+    # Initialize kernel
+    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+    if self.moe_quant_config is not None and (
+        (not self.moe.moe_parallel_config.use_all2all_kernels) or self.moe.moe_parallel_config.use_naive_all2all_kernels
+    ):
+        self.kernel = make_nvfp4_moe_kernel(
+            moe_quant_config=self.moe_quant_config,
+            moe_config=self.moe,
+            experts_cls=self.experts_cls,
+        )
+
+
+def _process_nvfp4_moe_flashinfer_cutlass(self, layer: torch.nn.Module, is_first_call: bool) -> None:
+    """Process MoE layer with FlashInfer/CUTLASS backend (W4A4).
+
+    Converts HF-format packed weights and scales to the kernel format via
+    vLLM's converter, then either installs new parameters (first call) or
+    copies into the existing tensors in place (reload) for address stability.
+    """
+    from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
+        convert_to_nvfp4_moe_kernel_format,
+        make_nvfp4_moe_kernel,
+    )
+    from vllm.model_executor.utils import replace_parameter
+
+    w13_packed = layer.w13_weight_packed.data
+    w2_packed = layer.w2_weight_packed.data
+    w13_scale_hf = layer.w13_weight_scale.data
+    w2_scale_hf = layer.w2_weight_scale.data
+
+    # Kernel uses a single per-expert global scale for w13; warn on w1/w3 mismatch.
+    if self.moe.is_act_and_mul and not torch.allclose(
+        layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
+    ):
+        logger.warning("w1_weight_global_scale must match w3_weight_global_scale. Accuracy may be affected.")
+    w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous()
+
+    # Clone so the converter works on temporaries, leaving the HF data intact.
+    w13_temp = Parameter(w13_packed.clone(), requires_grad=False)
+    w2_temp = Parameter(w2_packed.clone(), requires_grad=False)
+
+    if is_first_call:
+        layer.w13_weight = w13_temp
+        layer.w2_weight = w2_temp
+
+    (
+        w13,
+        w13_scale,
+        w13_scale_2,
+        a13_scale,
+        w2,
+        w2_scale,
+        w2_scale_2,
+        a2_scale,
+    ) = convert_to_nvfp4_moe_kernel_format(
+        nvfp4_backend=self.nvfp4_backend,
+        layer=layer,
+        w13=w13_temp,
+        w13_scale=w13_scale_hf,
+        # Scale-2 / activation scales are the reciprocals of checkpoint values.
+        w13_scale_2=(1.0 / w13_weight_global_scale),
+        a13_scale=(1.0 / layer.w13_input_global_scale),
+        w2=w2_temp,
+        w2_scale=w2_scale_hf,
+        w2_scale_2=(1.0 / layer.w2_weight_global_scale),
+        a2_scale=(1.0 / layer.w2_input_global_scale),
+        is_act_and_mul=self.moe.is_act_and_mul,
+    )
+
+    # Update parameters
+    if is_first_call:
+        replace_parameter(layer, "w13_weight", w13)
+        replace_parameter(layer, "w2_weight", w2)
+        layer.w13_weight_scale = Parameter(w13_scale, requires_grad=False)
+        layer.w2_weight_scale = Parameter(w2_scale, requires_grad=False)
+        if not hasattr(layer, "_marlin_tensor_refs"):
+            layer._marlin_tensor_refs = {}
+        layer._marlin_tensor_refs["w13_weight_scale"] = layer.w13_weight_scale.data
+        layer._marlin_tensor_refs["w2_weight_scale"] = layer.w2_weight_scale.data
+    else:
+        # In-place copies keep tensor addresses stable (CUDA Graph safety).
+        layer.w13_weight.data.copy_(w13.data)
+        layer.w2_weight.data.copy_(w2.data)
+        w13_scale_ref = layer._marlin_tensor_refs.get("w13_weight_scale")
+        w2_scale_ref = layer._marlin_tensor_refs.get("w2_weight_scale")
+        if w13_scale_ref is not None:
+            w13_scale_ref.copy_(w13_scale)
+            layer.w13_weight_scale = Parameter(w13_scale_ref, requires_grad=False)
+        else:
+            logger.warning("MoE W4A4: _marlin_tensor_refs['w13_weight_scale'] not found")
+            layer.w13_weight_scale.data.copy_(w13_scale)
+        if w2_scale_ref is not None:
+            w2_scale_ref.copy_(w2_scale)
+            layer.w2_weight_scale = Parameter(w2_scale_ref, requires_grad=False)
+        else:
+            logger.warning("MoE W4A4: _marlin_tensor_refs['w2_weight_scale'] not found")
+            layer.w2_weight_scale.data.copy_(w2_scale)
+
+    layer.w13_weight_scale_2 = w13_scale_2
+    layer.w2_weight_scale_2 = w2_scale_2
+    layer.w13_input_scale = a13_scale
+    layer.w2_input_scale = a2_scale
+
+    # Initialize kernel
+    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+    if self.moe_quant_config is not None and (
+        (not self.moe.moe_parallel_config.use_all2all_kernels) or self.moe.moe_parallel_config.use_naive_all2all_kernels
+    ):
+        self.kernel = make_nvfp4_moe_kernel(
+            moe_quant_config=self.moe_quant_config,
+            moe_config=self.moe,
+            experts_cls=self.experts_cls,
+        )
+
+
+# MoE NVFP4 Patches (entry points)
+def patched_nvfp4_moe_process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    """Patched process_weights_after_loading for NVFP4 MoE layer.
+
+    Entry point: records HF metadata on the first call, dispatches to the
+    backend-specific processor (Marlin vs FlashInfer/CUTLASS), then removes
+    the HF-format packed weights (rebuilt from metadata before each reload).
+    """
+    from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import NvFp4MoeBackend
+
+    is_first_call = _check_first_call(layer)
+
+    # Save metadata (first call only)
+    if is_first_call:
+        save_param_meta(layer, "w13_weight_packed")
+        save_param_meta(layer, "w2_weight_packed")
+        save_param_meta(layer, "w13_weight_scale")
+        save_param_meta(layer, "w2_weight_scale")
+        if not hasattr(layer, "_weight_loaders"):
+            layer._weight_loaders = {}
+        for pname in ["w13_weight_packed", "w2_weight_packed", "w13_weight_scale", "w2_weight_scale"]:
+            param = getattr(layer, pname, None)
+            if param is not None and hasattr(param, "weight_loader"):
+                layer._weight_loaders[pname] = param.weight_loader
+
+    is_marlin = self.nvfp4_backend == NvFp4MoeBackend.MARLIN
+    if is_marlin:
+        _process_nvfp4_moe_marlin(self, layer, is_first_call)
+    else:
+        _process_nvfp4_moe_flashinfer_cutlass(self, layer, is_first_call)
+
+    # Delete HF parameters
+    if hasattr(layer, "w13_weight_packed"):
+        delattr(layer, "w13_weight_packed")
+    if hasattr(layer, "w2_weight_packed"):
+        delattr(layer, "w2_weight_packed")
+
+
+_PATCH_TARGETS = [
+ # Dense W4A16
+ (
+ "vllm.model_executor.layers.quantization.compressed_tensors.schemes."
+ "compressed_tensors_w4a16_nvfp4.CompressedTensorsW4A16Fp4.process_weights_after_loading",
+ patched_w4a16_process_weights_after_loading,
+ ),
+ # Dense W4A4
+ (
+ "vllm.model_executor.layers.quantization.compressed_tensors.schemes."
+ "compressed_tensors_w4a4_nvfp4.CompressedTensorsW4A4Fp4.process_weights_after_loading",
+ patched_w4a4_process_weights_after_loading,
+ ),
+ # MoE NVFP4
+ (
+ "vllm.model_executor.layers.quantization.compressed_tensors."
+ "compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod.process_weights_after_loading",
+ patched_nvfp4_moe_process_weights_after_loading,
+ ),
+]
+
+_applied_patches = []
+
+
+def apply_qat_patches():
+    """Apply NVFP4 patches to support dynamic weight updates. Call before model loading.
+
+    Idempotent: a second call is a no-op that returns the already-applied
+    patchers. Patches are started but never stopped here, so they stay active
+    for the process lifetime.
+
+    Returns:
+        The list of started ``unittest.mock.patch`` objects.
+    """
+    global _applied_patches
+
+    if _applied_patches:
+        logger.warning("QAT patches already applied, skipping")
+        return _applied_patches
+
+    logger.info("Applying NVFP4 patches for dynamic weight loading...")
+
+    for target, replacement in _PATCH_TARGETS:
+        p = patch(target, replacement)
+        # Record before start() so a failed start is still visible in the list.
+        _applied_patches.append(p)
+        p.start()
+
+    logger.info(f"Applied {len(_applied_patches)} NVFP4 patches for dynamic weight loading")
+    return _applied_patches
+
+
+def prepare_qat_for_load_weights(model, device=None):
+    """
+    Prepare QAT model for weight loading. Call ONCE before multi-bucket weight loading.
+
+    Swaps Marlin-format scale tensors back to HF shape and rebuilds any
+    HF-format parameters that were deleted (W4A16) or overwritten (W4A4) by
+    process_weights_after_loading, so load_weights can consume HF checkpoints.
+
+    Args:
+        model: vLLM model (may be wrapped, exposing the real model at ``.model``)
+        device: Device for created parameters
+
+    Returns:
+        The ParamMetaDict built for this model (also stashed on the model as
+        ``_param_meta_for_restore``).
+    """
+    inner_model = model
+    if hasattr(model, "model"):
+        inner_model = model.model
+
+    param_meta = ParamMetaDict(inner_model, device=device)
+
+    param_meta.prepare_for_reload()
+    logger.info(f"[prepare_qat] Tensor swap prepared for {len(param_meta._tensor_swap_layers)} layers")
+
+    # Rebuild deleted (W4A16) or overwritten (W4A4) params back to HF format
+    rebuilt_count = 0
+    for layer_name, cache_entry in param_meta._layer_meta_cache.items():
+        module = cache_entry["module"]
+        for param_name, pm in cache_entry["meta"].items():
+            existing = getattr(module, param_name, None)
+            if existing is not None:
+                hf_shape = tuple(pm["shape"])
+                hf_dtype = pm["dtype"]
+                # Already in HF form with a loader attached -> nothing to do.
+                if (
+                    tuple(existing.shape) == hf_shape
+                    and existing.dtype == hf_dtype
+                    and hasattr(existing, "weight_loader")
+                ):
+                    continue
+            new_param = _create_param_from_meta(module, param_name, pm, device)
+            module.register_parameter(param_name, new_param)
+            rebuilt_count += 1
+
+    logger.info(f"[prepare_qat] Rebuilt {rebuilt_count} parameters")
+    # Keep the meta dict reachable so later stages can restore/rebuild params.
+    inner_model._param_meta_for_restore = param_meta
+    return param_meta
+
+
+def manual_process_weights_after_loading(model):
+    """Trigger weight post-processing for all quantized layers after load_weights.
+
+    Dense layers expose a ``scheme``; MoE-style layers expose only a
+    ``quant_method``. KV-cache quant methods are skipped (nothing to repack).
+
+    Returns:
+        Total number of layers processed (dense + MoE).
+    """
+    dense_count = 0
+    moe_count = 0
+
+    actual_model = model
+    if hasattr(model, "model"):
+        actual_model = model.model
+
+    for module in actual_model.modules():
+        if hasattr(module, "scheme"):
+            module.scheme.process_weights_after_loading(module)
+            dense_count += 1
+
+        # `not hasattr(module, "scheme")` keeps the two branches mutually
+        # exclusive: a dense layer is not double-processed via its quant_method.
+        quant_method = getattr(module, "quant_method", None)
+        if quant_method is not None and not hasattr(module, "scheme"):
+            if hasattr(quant_method, "process_weights_after_loading"):
+                # Skip KV cache quantization methods
+                if "KVCache" in quant_method.__class__.__name__:
+                    continue
+                quant_method.process_weights_after_loading(module)
+                moe_count += 1
+
+    logger.debug(f"Processed {dense_count} dense layers, {moe_count} MoE layers")
+    return dense_count + moe_count
+
+
+__all__ = [
+ "apply_qat_patches",
+ "prepare_qat_for_load_weights",
+ "manual_process_weights_after_loading",
+]
diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py
index 5ba20649365..eff3d91085f 100644
--- a/verl/utils/ray_utils.py
+++ b/verl/utils/ray_utils.py
@@ -97,9 +97,13 @@ def get_event_loop():
def auto_await(func):
"""Auto await a coroutine function.
- If the function is called in an async context (with a running event loop),
- it will return the coroutine object. Otherwise, it will block the current thread
- and run the coroutine until completion.
+ Handles three cases:
+ 1. When the decorated function is called with await: returns the coroutine
+ so the caller can await it.
+ 2. When called directly and there is no running event loop: runs the
+ coroutine with asyncio.run() and returns the result.
+ 3. When called directly and the event loop is already running: runs the
+ coroutine (e.g. in a thread pool to avoid deadlock) and returns the result.
"""
@functools.wraps(func)
@@ -114,9 +118,22 @@ def wrapper(*args, **kwargs):
except RuntimeError:
loop = None
- if loop and loop.is_running():
- return coro
- else:
+ # Case 1: No running loop -> run with asyncio.run()
+ if loop is None:
return asyncio.run(coro)
+ # Case 2: Running loop -> return coro if caller will await
+ caller_frame = inspect.currentframe()
+ if caller_frame is not None:
+ caller_frame = caller_frame.f_back
+ caller_is_async = caller_frame is not None and (caller_frame.f_code.co_flags & inspect.CO_COROUTINE) != 0
+ if caller_is_async:
+ return coro
+
+ # Case 3: Running loop -> run coro in thread pool
+ # (cannot block the loop thread without deadlock)
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+ future = pool.submit(asyncio.run, coro)
+ return future.result()
+
return wrapper
diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py
index 46f82240448..51097f50a51 100644
--- a/verl/utils/seqlen_balancing.py
+++ b/verl/utils/seqlen_balancing.py
@@ -388,7 +388,7 @@ def rearrange_micro_batches(
if min_num_micro_batch is not None:
# used to support pp
num_micro_batches = max(min_num_micro_batch, num_micro_batches)
- if dist.is_initialized() and same_micro_num_in_dp:
+ if dist.is_initialized() and same_micro_num_in_dp and dp_group is not None:
num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name())
dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
num_micro_batches = num_micro_batches.cpu().item()
diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py
index 2802e3642f1..8666bec2d16 100644
--- a/verl/utils/torch_functional.py
+++ b/verl/utils/torch_functional.py
@@ -710,6 +710,7 @@ def get_cosine_schedule_with_warmup(
num_cycles: float = 0.5,
last_epoch: int = -1,
init_lr_ratio: float = None,
+ zero_indexed_step: bool = True,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
@@ -731,6 +732,9 @@ def get_cosine_schedule_with_warmup(
The index of the last epoch when resuming training.
init_lr_ratio (:obj:`float`, `optional`, defaults to None):
The initial lr ratio w.r.t the maximum.
+ zero_indexed_step (:obj:`bool`, `optional`, defaults to True):
+ Whether the LR schedule uses 0-indexed steps. If True (default), step counting starts at 0.
+ If False (used by torchtitan), step counting starts at 1.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
@@ -743,6 +747,8 @@ def get_cosine_schedule_with_warmup(
assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0
def lr_lambda(current_step):
+ if not zero_indexed_step:
+ current_step += 1
if current_step < num_warmup_steps:
return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps)))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py
deleted file mode 100644
index 6014f4bc03e..00000000000
--- a/verl/utils/transferqueue_utils.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import functools
-import inspect
-import logging
-import os
-import threading
-from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable
-
-if TYPE_CHECKING:
- from verl.single_controller.base.decorator import Dispatch
-
-from tensordict import TensorDict
-
-try:
- from transfer_queue import (
- AsyncTransferQueueClient,
- BatchMeta,
- TransferQueueClient,
- )
-
-except ImportError:
- # TODO: Use a hacky workaround for ImportError since
- # transfer_queue isn't a default verl dependency.
- class BatchMeta:
- pass
-
-
-from verl.protocol import DataProto
-
-logger = logging.getLogger(__name__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-_TRANSFER_QUEUE_CLIENT = None
-
-is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False)
-
-
-def create_transferqueue_client(
- client_id: str,
- config,
- sync: bool = False,
-) -> "AsyncTransferQueueClient | TransferQueueClient":
- global _TRANSFER_QUEUE_CLIENT
- if _TRANSFER_QUEUE_CLIENT is None:
- if sync:
- _TRANSFER_QUEUE_CLIENT = TransferQueueClient(client_id, config.controller_info)
- else:
- _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, config.controller_info)
- _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=config.storage_backend, config=config)
-
- return _TRANSFER_QUEUE_CLIENT
-
-
-def get_transferqueue_client() -> "AsyncTransferQueueClient | TransferQueueClient":
- return _TRANSFER_QUEUE_CLIENT
-
-
-# TODO (TQ): verl will make all actor async, so this can be cleanup later.
-def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any:
- # Use a temporary event loop in a new thread because event
- # loop may already exist in server mode
- tmp_event_loop = asyncio.new_event_loop()
- thread = threading.Thread(
- target=tmp_event_loop.run_forever,
- name="batchmeta dataproto converter",
- daemon=True,
- )
-
- def run_coroutine(coroutine):
- if not thread.is_alive():
- thread.start()
- future = asyncio.run_coroutine_threadsafe(coroutine, tmp_event_loop)
- return future.result()
-
- async def stop_loop():
- tmp_event_loop.stop()
-
- try:
- return run_coroutine(async_func(*args, **kwargs))
- finally:
- if thread.is_alive():
- asyncio.run_coroutine_threadsafe(stop_loop(), tmp_event_loop)
- thread.join()
-
-
-def _find_batchmeta(*args, **kwargs):
- for arg in args:
- if isinstance(arg, BatchMeta):
- return arg
- for v in kwargs.values():
- if isinstance(v, BatchMeta):
- return v
- return None
-
-
-async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto:
- if batchmeta.samples == [] or batchmeta.samples is None:
- return DataProto(
- batch=TensorDict({}, batch_size=(0,)),
- non_tensor_batch={},
- meta_info=batchmeta.extra_info.copy(),
- )
-
- tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
- return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy())
-
-
-def _batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto:
- return _run_async_in_temp_loop(_async_batchmeta_to_dataproto, batchmeta)
-
-
-async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta":
- pid = os.getpid()
-
- for k, v in output.meta_info.items():
- batchmeta.set_extra_info(k, v)
-
- if len(output) > 0:
- tensordict = output.to_tensordict()
- # pop meta_info
- for key in output.meta_info.keys():
- tensordict.pop(key)
-
- logger.info(
- f"Task {func_name} (pid={pid}) putting output data to TransferQueue with "
- f"batch_size={tensordict.batch_size},\n"
- f"tensordict keys={list(tensordict.keys())}"
- )
-
- updated_batch_meta = await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
- return updated_batch_meta
- else:
- return batchmeta
-
-
-def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta":
- updated_batch_meta = _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta, func_name)
- return updated_batch_meta
-
-
-def _compute_need_collect(dispatch_mode: "dict | Dispatch", args: list) -> bool:
- """Compute whether data collection is needed for the current worker.
-
- This function determines whether the current worker should collect data based on
- the dispatch mode configuration and worker parameters. It's used to optimize
- distributed data collection by ensuring only the appropriate rank collects data.
-
- Args:
- dispatch_mode: Controls data collection logic for the current worker. Can be None,
- a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch,
- always returns True (current worker should collect). If dict, checks
- collect_fn for lazy compute optimization.
- args: List of arguments passed to the function. Should contain a Worker instance
- as the first argument when using lazy compute mode.
-
- Returns:
- bool: True if data collection is needed, False otherwise.
-
- Note:
- Only checks worker attributes when dispatch_mode is a dict with 'collect_fn',
- the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker.
- Otherwise, returns True. For the lazy compute case, checks the worker's
- data parallel rank for the mesh specified in collect_fn.args[0] to determine
- if this worker should collect data.
- """
- from verl.single_controller.base.decorator import Dispatch
- from verl.single_controller.base.worker import Worker
-
- if dispatch_mode is None or isinstance(dispatch_mode, Dispatch):
- return True
-
- assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode."
-
- collect_fn = dispatch_mode["collect_fn"]
-
- # Check if collect_fn is a functools.partial and handle gracefully
- if isinstance(collect_fn, functools.partial):
- collect_fn_name = collect_fn.func.__name__
- if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker):
- return True
-
- collect_mesh_name = collect_fn.args[0] if collect_fn.args else None
- if collect_mesh_name is None:
- return True
-
- return args[0].query_collect_info(collect_mesh_name)
- else:
- # If collect_fn is not a partial, we can't extract mesh_name information
- # Fall back to default behavior (collect data)
- return True
-
-
-def _postprocess_common(output, put_data, need_collect):
- """Common post-processing logic for function outputs in TransferQueue bridge.
-
- This function handles the final return value based on whether data should be
- put into storage (put_data) and whether collection is needed (need_collect).
- It ensures proper return types based on the execution context.
-
- Args:
- output: The original output from the decorated function. Can be any type.
- put_data: bool, indicating whether the output should be put into TransferQueue.
- If True, output will be put to TQ and return the corresponding BatchMeta;
- if False, output will not be put into TQ.
- need_collect: bool, indicating whether this process needs to collect data.
- If False, the output will be replaced by an empty BatchMeta or DataProto
- to avoid redundant communication.
-
- Returns:
- - BatchMeta.empty(): When put_data=True but need_collect=False, indicating
- no data should be stored but BatchMeta structure is expected.
- - DataProto(): When put_data=False, need_collect=False, and output is DataProto,
- returning an empty DataProto.
- - output: In all other cases, returns the original output unchanged.
-
- Note:
- This function is used in the tqbridge decorator to normalize return values
- across different execution paths and avoid redundant data operations in
- distributed scenarios.
- """
- if put_data and not need_collect:
- return BatchMeta.empty()
- elif not put_data and not need_collect and isinstance(output, DataProto):
- return DataProto()
- else:
- return output
-
-
-def tqbridge(dispatch_mode: "dict | Dispatch" = None, put_data: bool = True):
- """Creates a decorator for bridging BatchMeta and DataProto.
-
- This decorator automatically handles conversions between `BatchMeta` and
- `DataProto` in function parameters, and decides whether to sync function
- output back to `BatchMeta` based on configuration(`put_data`). It supports
- both synchronous and asynchronous functions (async def), and can control
- whether to enable enhanced logic via the global `HAS_TQ` variable (when disabled,
- simply calls the original function as-is).
-
- Args:
- dispatch_mode: Controls data collection behavior for the current worker. Passed to
- _compute_need_collect to determine if current worker should collect data.
- If None, _compute_need_collect will return True to fallback default logics.
- put_data: Whether put the DataProto into Storage after func return.
- If True, after function execution, the output result will be
- updated to `BatchMeta` and `BatchMeta` will be returned;
- If False, the function output result will be returned directly.
- Defaults to True.
-
- Returns:
- A decorator function used to decorate target functions (synchronous or asynchronous).
- """
-
- def decorator(func):
- pid = os.getpid()
-
- @wraps(func)
- def inner(*args, **kwargs):
- batchmeta = _find_batchmeta(*args, **kwargs)
- if batchmeta is None:
- return func(*args, **kwargs)
- else:
- logger.info(
- f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
- f"global_idx={batchmeta.global_indexes}"
- )
- args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
- kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
- output = func(*args, **kwargs)
- need_collect = _compute_need_collect(dispatch_mode, args)
- if put_data and need_collect:
- updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__)
- return updated_batch_meta
- return _postprocess_common(output, put_data, need_collect)
-
- @wraps(func)
- async def async_inner(*args, **kwargs):
- batchmeta = _find_batchmeta(*args, **kwargs)
- if batchmeta is None:
- return await func(*args, **kwargs)
- else:
- logger.info(
- f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
- f"global_idx={batchmeta.global_indexes}"
- )
- args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
- kwargs = {
- k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v
- for k, v in kwargs.items()
- }
- output = await func(*args, **kwargs)
- need_collect = _compute_need_collect(dispatch_mode, args)
- if put_data and need_collect:
- updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__)
- return updated_batchmeta
- return _postprocess_common(output, put_data, need_collect)
-
- @wraps(func)
- def dummy_inner(*args, **kwargs):
- output = func(*args, **kwargs)
- return output
-
- @wraps(func)
- async def dummy_async_inner(*args, **kwargs):
- output = await func(*args, **kwargs)
- return output
-
- wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
- wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner
-
- wrapper = wrapper_async_inner if inspect.iscoroutinefunction(func) else wrapper_inner
- return wrapper
-
- return decorator
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index d524f0e2ba1..c13ffa09c2b 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -412,6 +412,13 @@ def _optimizer_step(self):
self.actor_optimizer.zero_grad()
else:
self.actor_optimizer.step()
+
+ # Clear cached weight scales for QAT (weights changed)
+ if getattr(self.actor_module, "_qat_fuse_enabled", False):
+ from verl.utils.qat import invalidate_all_scales
+
+ invalidate_all_scales(self.actor_module)
+
return grad_norm
@GPUMemoryLogger(role="dp actor", logger=logger)
diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index 7fdaa6e9811..f4a697866ab 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -139,6 +139,11 @@ def __init__(
assert self.mtp_config.enable, "MTP requires mtp_config.enable to be True"
self.use_fused_kernels = self.config.get("use_fused_kernels", False)
+ if getattr(self.mtp_config, "enable", False) and self.use_fused_kernels:
+ self.use_fused_kernels = False
+ logger.warning_once(
+ "MTP is not compatible with fused kernels for now. Automatically disable use_fused_kernels."
+ )
if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False):
# do not patch if overlap_moe_expert_parallel_comm is enabled
logger.warning_once(
@@ -434,6 +439,7 @@ def forward_backward_batch(
temperature = data.meta_info["temperature"]
if use_dynamic_bsz:
assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
+ dp_group = mpu.get_data_parallel_group()
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is not None and vpp_size > 1:
microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
@@ -441,13 +447,16 @@ def forward_backward_batch(
batch=mini_batch.batch,
num_batches_divided_by=microbatch_group_size_per_vp_stage,
max_token_len=max_token_len,
+ dp_group=dp_group,
)
assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (
f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage "
f"{microbatch_group_size_per_vp_stage} for megatron backend"
)
else:
- micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
+ micro_batches, indices = rearrange_micro_batches(
+ batch=mini_batch.batch, max_token_len=max_token_len, dp_group=dp_group
+ )
total_seqlen = max_token_len
else:
assert micro_batch_size is not None, (
diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py
index bcf05e8f2d2..66071e8ec20 100644
--- a/verl/workers/config/actor.py
+++ b/verl/workers/config/actor.py
@@ -18,10 +18,11 @@
from omegaconf import MISSING
from verl.base_config import BaseConfig
-from verl.trainer.config import CheckpointConfig
+from verl.trainer.config import CheckpointConfig, RolloutCorrectionConfig
from verl.utils.profiler.config import ProfilerConfig
+from verl.utils.qat import QATConfig
-from .engine import FSDPEngineConfig, McoreEngineConfig, VeOmniEngineConfig
+from .engine import FSDPEngineConfig, McoreEngineConfig, TorchtitanEngineConfig, VeOmniEngineConfig
from .model import HFModelConfig
from .optimizer import OptimizerConfig
@@ -32,6 +33,8 @@
"FSDPActorConfig",
"McoreActorConfig",
"VeOmniActorConfig",
+ "QATConfig",
+ "TorchTitanActorConfig",
]
@@ -77,6 +80,7 @@ class PolicyLossConfig(BaseConfig):
clip_cov_ub (float): Upper bound for clip-cov loss.
kl_cov_ratio (float): Ratio of tokens to be applied KL penalty for kl-cov loss.
ppo_kl_coef (float): KL divergence penalty coefficient.
+ rollout_correction (RolloutCorrectionConfig): Configuration for rollout correction.
"""
loss_mode: str = "vanilla"
@@ -85,6 +89,7 @@ class PolicyLossConfig(BaseConfig):
clip_cov_ub: float = 5.0
kl_cov_ratio: float = 0.0002
ppo_kl_coef: float = 0.1
+ rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig)
@dataclass
@@ -294,6 +299,7 @@ class FSDPActorConfig(ActorConfig):
use_rollout_log_probs: bool = False
calculate_sum_pi_squared: bool = False
sum_pi_squared_checkpointing: bool = False
+ qat: QATConfig = field(default_factory=QATConfig)
def __post_init__(self):
"""Validate FSDP actor configuration parameters."""
@@ -336,3 +342,27 @@ def __post_init__(self):
"""Validate VeOmni actor configuration parameters."""
super().__post_init__()
self.engine = self.veomni
+
+
+@dataclass
+class TorchTitanActorConfig(ActorConfig):
+ """Configuration for TorchTitan actor models.
+
+ The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
+
+ Args:
+ strategy (str): Training strategy set to 'torchtitan' for TorchTitan parallelism.
+ torchtitan (TorchtitanEngineConfig): Configuration for TorchTitan engine settings.
+ use_remove_padding (bool): Whether to remove padding tokens in inputs during training
+ use_rollout_log_probs (bool): Whether to use log probabilities from rollout engine
+ """
+
+ strategy: str = "torchtitan"
+ torchtitan: TorchtitanEngineConfig = field(default_factory=TorchtitanEngineConfig)
+ use_remove_padding: bool = False
+ use_rollout_log_probs: bool = False
+
+ def __post_init__(self):
+ """Validate TorchTitan actor configuration parameters."""
+ super().__post_init__()
+ self.engine = self.torchtitan
diff --git a/verl/workers/config/critic.py b/verl/workers/config/critic.py
index c347b54e754..caca5bac6ac 100644
--- a/verl/workers/config/critic.py
+++ b/verl/workers/config/critic.py
@@ -22,11 +22,11 @@
from verl.trainer.config import BaseModelConfig, CheckpointConfig
from verl.utils.profiler import ProfilerConfig
-from .engine import FSDPEngineConfig, McoreEngineConfig
+from .engine import FSDPEngineConfig, McoreEngineConfig, TorchtitanEngineConfig
from .model import HFModelConfig
from .optimizer import OptimizerConfig
-__all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "FSDPCriticModelCfg"]
+__all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "TorchTitanCriticConfig", "FSDPCriticModelCfg"]
@dataclass
@@ -224,6 +224,26 @@ def validate(self, n_gpus: int, train_batch_size: int):
)
+@dataclass
+class TorchTitanCriticConfig(CriticConfig):
+ """Configuration for TorchTitan-based critic model training.
+
+ The inheritance from CriticConfig provides all base critic configuration plus TorchTitan-specific settings.
+
+ Args:
+ strategy (str): Training strategy set to 'torchtitan' for TorchTitan parallelism.
+ torchtitan (TorchtitanEngineConfig): Configuration for TorchTitan engine settings.
+ """
+
+ strategy: str = "torchtitan"
+ torchtitan: TorchtitanEngineConfig = field(default_factory=TorchtitanEngineConfig)
+
+ def __post_init__(self):
+ """Validate TorchTitan critic configuration parameters."""
+ super().__post_init__()
+ self.engine = self.torchtitan
+
+
@dataclass
class FSDPCriticModelCfg(BaseModelConfig):
"""FSDP-enabled critic model configuration.
diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py
index feb559b374c..6caa57d3190 100644
--- a/verl/workers/config/engine.py
+++ b/verl/workers/config/engine.py
@@ -27,6 +27,7 @@
"FSDPEngineConfig",
"McoreEngineConfig",
"TrainingWorkerConfig",
+ "TorchtitanEngineConfig",
"VeOmniEngineConfig",
"EngineConfig",
"EngineRouterReplayConfig",
@@ -73,6 +74,8 @@ class EngineConfig(BaseConfig):
"infer_micro_batch_size_per_gpu",
"use_fused_kernels",
"use_remove_padding",
+ "forward_only",
+ "param_offload",
}
# whether to offload param
param_offload: bool = False
@@ -235,14 +238,9 @@ class VeOmniEngineConfig(EngineConfig):
optimizer_offload (bool): Whether to offload optimizer states to CPU, default False
offload_policy (bool): Whether to offload policy model parameters, default False
reshard_after_forward (bool): Whether to reshard parameters after forward pass, default True
- data_parallel_size (int): FSDP group size, default 1
- data_parallel_replicate_size (int): Data parallel replicate size, default 1
- data_parallel_shard_size (int): Data parallel shard degree, default 1
- tensor_parallel_size (int): Tensor parallel size, default 1
- expert_parallel_size (int): Expert parallel size, default 1
- pipeline_parallel_size (int): Pipeline parallel size, default 1
- context_parallel_size (int): Ring-attn context parallel size, default 1
+ fsdp_size (int): FSDP group size. -1 means use all available GPUs, default -1
ulysses_parallel_size (int): Ulysses sequence parallel size, default 1
+ expert_parallel_size (int): Expert parallel size, default 1
init_device (str): Device to initialize model weights.
1. `cpu`: Init parameters on CPU in rank0 only.
2. `cuda`: Init parameters on GPU.
@@ -291,14 +289,9 @@ class VeOmniEngineConfig(EngineConfig):
use_torch_compile: bool = True
entropy_checkpointing: bool = False
strategy: str = "veomni"
- data_parallel_size: int = 1
- data_parallel_replicate_size: int = 1
- data_parallel_shard_size: int = 1
- tensor_parallel_size: int = 1
- expert_parallel_size: int = 1
- pipeline_parallel_size: int = 1
- context_parallel_size: int = 1
+ fsdp_size: int = -1
ulysses_parallel_size: int = 1
+ expert_parallel_size: int = 1
seed: int = 42
full_determinism: bool = False
mixed_precision: bool = False
@@ -319,6 +312,66 @@ def __post_init__(self):
assert self.strategy in ["veomni"], f"strategy {self.strategy} not supported"
+@dataclass
+class TorchtitanEngineConfig(EngineConfig):
+ """Configuration for Torchtitan.
+
+ The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
+
+ Args:
+ wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy.
+ reshard_after_forward (Literal["default", "always", "never"]): The policy for applying
+ `reshard_after_forward` within an FSDP setup, default "default"
+ forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False
+ use_orig_params (bool): Whether to use original parameters when initialize FSDP, default False
+ mixed_precision (bool): Mixed precision configuration for FSDP, default False
+ offload_policy (bool): Whether to offload policy model parameters, default False
+ data_parallel_size (int): Data parallel group size, default 1
+ data_parallel_replicate_size (int): Data parallel replicate size, default 1
+ data_parallel_shard_size (int): Data parallel shard degree, default 1
+ tensor_parallel_size (int): Tensor parallel size, default 1
+ expert_parallel_size (int): Expert parallel size, default 1
+ expert_tensor_parallel_size (int): Expert tensor parallel size, default 1
+ pipeline_parallel_size (int): Pipeline parallel size, default 1
+ context_parallel_size (int): Context parallel size, default 1
+ attn_type (str): Attention type for torchtitan's model (e.g., "sdpa", "flex", "varlen"),
+ default "flex"
+ strategy (str): Strategy to use for distributed training, default "torchtitan"
+ seed (int): Random seed for reproducibility.
+ full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results
+ in distributed training. Important: this will negatively impact performance, so only use it for
+ debugging.
+
+ """
+
+ wrap_policy: dict[str, Any] = field(default_factory=dict)
+ reshard_after_forward: Literal["default", "always", "never"] = "default"
+ forward_prefetch: bool = False
+ use_orig_params: bool = False
+ mixed_precision: bool = False
+ offload_policy: bool = False
+ use_torch_compile: bool = True
+ entropy_from_logits_with_chunking: bool = False
+ entropy_checkpointing: bool = False
+ data_parallel_size: int = 1
+ data_parallel_replicate_size: int = 1
+ data_parallel_shard_size: int = 1
+ tensor_parallel_size: int = 1
+ expert_parallel_size: int = 1
+ expert_tensor_parallel_size: int = 1
+ pipeline_parallel_size: int = 1
+ context_parallel_size: int = 1
+ attn_type: str = "flex"
+ max_seq_len: Optional[int] = None
+ strategy: str = "torchtitan"
+ seed: int = 42
+ full_determinism: bool = False
+
+ def __post_init__(self):
+ super().__post_init__()
+ assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported"
+
+
@dataclass
class TrainingWorkerConfig(BaseConfig):
model_type: str = None # model type (language_model/value_model)
diff --git a/verl/workers/config/model.py b/verl/workers/config/model.py
index 1aa2afc0843..9205a99f038 100644
--- a/verl/workers/config/model.py
+++ b/verl/workers/config/model.py
@@ -119,7 +119,7 @@ class HFModelConfig(BaseConfig):
# fsdp lora related. We may setup a separate config later
lora_rank: int = 0
lora_alpha: int = 16
- target_modules: Optional[str] = "all-linear"
+ target_modules: Optional[Any] = "all-linear" # allow both "all-linear" and ["q_proj","k_proj"]
target_parameters: Optional[list[str]] = None # for lora adapter on nn.Parameter
exclude_modules: Optional[str] = None
@@ -204,5 +204,19 @@ def __post_init__(self):
if getattr(self.hf_config, "model_type", None) == "kimi_vl":
self.hf_config.text_config.topk_method = "greedy"
+ # Ensure target_modules is a str or list[str] (only if not None)
+ if self.target_modules is not None:
+ if not isinstance(self.target_modules, (str | list)):
+ raise TypeError(
+ "target_modules must be a string or a list of strings, "
+ f"but got {type(self.target_modules).__name__}"
+ )
+ if isinstance(self.target_modules, list):
+ for x in self.target_modules:
+ if not isinstance(x, str):
+ raise TypeError(
+ f"All elements in target_modules list must be strings, but found {type(x).__name__}"
+ )
+
def get_processor(self):
return self.processor if self.processor is not None else self.tokenizer
diff --git a/verl/workers/config/optimizer.py b/verl/workers/config/optimizer.py
index bdb87667c25..b7f05bef518 100644
--- a/verl/workers/config/optimizer.py
+++ b/verl/workers/config/optimizer.py
@@ -19,7 +19,14 @@
from verl.base_config import BaseConfig
-__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig", "build_optimizer", "VeOmniOptimizerConfig"]
+__all__ = [
+ "OptimizerConfig",
+ "FSDPOptimizerConfig",
+ "McoreOptimizerConfig",
+ "build_optimizer",
+ "VeOmniOptimizerConfig",
+ "TorchtitanOptimizerConfig",
+]
@dataclass
@@ -88,6 +95,8 @@ class FSDPOptimizerConfig(OptimizerConfig):
min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
num_cycles (float): Number of cosine cycles in LR schedule.
+ zero_indexed_step (bool): Whether the LR schedule uses 0-indexed steps. If True (default),
+ step counting starts at 0. If False, step counting starts at 1.
"""
_mutable_fields = OptimizerConfig._mutable_fields.copy()
@@ -101,6 +110,7 @@ class FSDPOptimizerConfig(OptimizerConfig):
lr_scheduler_type: str = "constant"
num_cycles: float = 0.5
override_optimizer_config: Optional[dict] = None
+ zero_indexed_step: bool = True
def __post_init__(self):
if self.warmup_style is not None:
@@ -143,6 +153,23 @@ class McoreOptimizerConfig(OptimizerConfig):
override_optimizer_config: Optional[dict] = None
+@dataclass
+class TorchtitanOptimizerConfig(OptimizerConfig):
+ """Torchtitan optimizer configuration extending base OptimizerConfig.
+
+ Args:
+ name (str): Optimizer name; default is "AdamW".
+ eps (float): Epsilon value for AdamW optimizer, default 1e-8.
+ decay_type (str): Weight decay type: "linear", "sqrt", or "cosine".
+ min_lr_factor (float): Minimum learning rate factor.
+ """
+
+ name: str = "AdamW"
+ eps: float = 1e-8
+ decay_type: str = "linear"
+ min_lr_factor: float = 0.0
+
+
def build_optimizer(parameters, config: FSDPOptimizerConfig):
"""Build an optimizer based on the configuration.
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index 3b4e7c121a3..d1d5c8f1768 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -80,6 +80,8 @@ class AgentLoopConfig(BaseConfig):
@dataclass
class TraceConfig(BaseConfig):
+ project_name: Optional[str] = None
+ experiment_name: Optional[str] = None
backend: Optional[str] = None
token2text: bool = False
max_samples_per_step_per_worker: Optional[int] = None
@@ -125,7 +127,7 @@ class CheckpointEngineConfig(BaseConfig):
"""
# Backend for checkpoint engine: naive, nccl, nixl, hccl
- backend: Optional[str] = MISSING
+ backend: Optional[str] = "naive"
# Bucket size in MB to transfer multiple weights at one time
update_weights_bucket_megabytes: int = 2048
# Additional keyword arguments for checkpoint engine
@@ -138,6 +140,8 @@ class RolloutConfig(BaseConfig):
name: Optional[str] = MISSING
mode: str = "async"
+ nnodes: int = 0
+ n_gpus_per_node: int = 8
temperature: float = 1.0
top_k: int = -1
@@ -238,6 +242,8 @@ class RolloutConfig(BaseConfig):
mtp: MtpConfig = field(default_factory=MtpConfig)
+ qat: Optional[dict] = None
+
def __post_init__(self):
"""Validate the rollout config"""
# Deprecation warning for mode field - only async mode is supported
diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py
index ecc166cd495..af87cb74c09 100644
--- a/verl/workers/critic/megatron_critic.py
+++ b/verl/workers/critic/megatron_critic.py
@@ -181,6 +181,7 @@ def forward_backward_batch(
indices = None
if use_dynamic_bsz:
assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
+ dp_group = mpu.get_data_parallel_group()
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is not None and vpp_size > 1:
microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
@@ -188,13 +189,16 @@ def forward_backward_batch(
batch=mini_batch.batch,
num_batches_divided_by=microbatch_group_size_per_vp_stage,
max_token_len=max_token_len,
+ dp_group=dp_group,
)
assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (
f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage "
f"{microbatch_group_size_per_vp_stage} for megatron backend"
)
else:
- micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
+ micro_batches, indices = rearrange_micro_batches(
+ batch=mini_batch.batch, max_token_len=max_token_len, dp_group=dp_group
+ )
total_seqlen = max_token_len
else:
assert micro_batch_size is not None, (
diff --git a/verl/workers/engine/__init__.py b/verl/workers/engine/__init__.py
index 7b8be1002c0..8f01080fdcb 100644
--- a/verl/workers/engine/__init__.py
+++ b/verl/workers/engine/__init__.py
@@ -21,6 +21,14 @@
"FSDPEngineWithLMHead",
]
+try:
+ from .torchtitan import TorchTitanEngine, TorchTitanEngineWithLMHead
+
+ __all__ += ["TorchTitanEngine", "TorchTitanEngineWithLMHead"]
+except ImportError:
+ TorchTitanEngine = None
+ TorchTitanEngineWithLMHead = None
+
try:
from .veomni import VeOmniEngine, VeOmniEngineWithLMHead
diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py
index cba8909ff64..dbe9eb2f4e1 100644
--- a/verl/workers/engine/fsdp/transformer_impl.py
+++ b/verl/workers/engine/fsdp/transformer_impl.py
@@ -62,9 +62,14 @@
from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs
from verl.utils.py_functional import convert_to_regular_types
from verl.utils.torch_functional import logprobs_from_logits
-from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
+from verl.utils.ulysses import (
+ gather_outputs_and_unpad,
+ get_ulysses_sequence_parallel_group,
+ set_ulysses_sequence_parallel_group,
+ ulysses_pad,
+ ulysses_pad_and_slice_inputs,
+)
from verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig, HFModelConfig
-from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from ..base import BaseEngine, BaseEngineCtx, EngineRegistry
from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches
@@ -190,14 +195,15 @@ def _init_device_mesh(self):
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
+ self.ulysses_parallel_group = None
self.ulysses_sequence_parallel_size = self.engine_config.ulysses_sequence_parallel_size
dp_size = self.get_data_parallel_size()
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
device_name, mesh_shape=(dp_size, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
)
+ self.ulysses_parallel_group = self.ulysses_device_mesh["sp"].get_group()
- self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
def _build_module(self):
@@ -406,6 +412,7 @@ def _build_lr_scheduler(self, optimizer):
lr_scheduler_type = optim_config.lr_scheduler_type
min_lr_ratio = optim_config.min_lr_ratio
num_cycles = optim_config.num_cycles
+ zero_indexed_step = optim_config.zero_indexed_step
if num_warmup_steps <= 0:
num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
@@ -422,6 +429,7 @@ def _build_lr_scheduler(self, optimizer):
num_training_steps=total_steps,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
+ zero_indexed_step=zero_indexed_step,
)
else:
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
@@ -707,12 +715,13 @@ def __init__(self, engine: FSDPEngine, **kwargs):
def __enter__(self):
assert isinstance(self.engine, FSDPEngine)
super().__enter__()
- self.engine.ulysses_sharding_manager.__enter__()
+ self.prev_sp_group = get_ulysses_sequence_parallel_group()
+ set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group)
self.engine.module.eval()
def __exit__(self, exc_type, exc_value, traceback):
assert isinstance(self.engine, FSDPEngine)
- self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
+ set_ulysses_sequence_parallel_group(self.prev_sp_group)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
@@ -732,12 +741,13 @@ def __init__(self, engine: FSDPEngine, **kwargs):
def __enter__(self):
assert isinstance(self.engine, FSDPEngine)
super().__enter__()
- self.engine.ulysses_sharding_manager.__enter__()
+ self.prev_sp_group = get_ulysses_sequence_parallel_group()
+ set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group)
self.engine.module.train()
def __exit__(self, exc_type, exc_value, traceback):
assert isinstance(self.engine, FSDPEngine)
- self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
+ set_ulysses_sequence_parallel_group(self.prev_sp_group)
self.engine.optimizer_zero_grad()
super().__exit__(exc_type, exc_value, traceback)
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 0e3f7ff6a29..5cb0824a96b 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -152,6 +152,10 @@ def _build_tf_config(self):
# In case of invalid overrides, we need to make sure some critical params are set correctly
provider.params_dtype = self.param_dtype
+ # Ensure dtype settings propagate to Megatron-Bridge/TE
+ provider.fp16 = self.param_dtype == torch.float16
+ provider.bf16 = self.param_dtype == torch.bfloat16
+
# Pass distributed info
provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
@@ -315,6 +319,8 @@ def initialize(self):
if self.engine_config.forward_only:
self.optimizer = None
self.lr_scheduler = None
+ self.to(device="cpu", model=self._is_offload_param, optimizer=False, grad=False)
+ log_gpu_memory_usage("After offload model during init (forward_only)", logger=logger)
return
self.optimizer = self._build_optimizer()
@@ -598,12 +604,14 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
return {}
def get_per_tensor_param(self, base_sync_done=False, **kwargs):
- load_megatron_model_to_gpu(self.module, load_grad=False)
peft_config = None
non_merge_lora_sync = self.peft_cls is not None and not self.model_config.lora.get("merge", False)
+ adapter_only = base_sync_done and non_merge_lora_sync
+        # When syncing only the LoRA adapter (adapter_only), load just the adapter weights once the base sync is done; otherwise load all weights
+ load_megatron_model_to_gpu(self.module, load_grad=False, load_frozen_params=not adapter_only)
if self.vanilla_bridge:
per_tensor_param = self.bridge.export_weights(self.module)
- elif base_sync_done and non_merge_lora_sync:
+ elif adapter_only:
# Only export adapter weights
peft_config = build_peft_config_for_vllm(self.model_config.lora)
per_tensor_param = self.bridge.export_adapter_weights(self.module)
diff --git a/verl/workers/engine/torchtitan/__init__.py b/verl/workers/engine/torchtitan/__init__.py
new file mode 100644
index 00000000000..345757277af
--- /dev/null
+++ b/verl/workers/engine/torchtitan/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .transformer_impl import TorchTitanEngine, TorchTitanEngineWithLMHead
+
+__all__ = ["TorchTitanEngine", "TorchTitanEngineWithLMHead"]
diff --git a/verl/workers/engine/torchtitan/transformer_impl.py b/verl/workers/engine/torchtitan/transformer_impl.py
new file mode 100644
index 00000000000..002ea20e4ff
--- /dev/null
+++ b/verl/workers/engine/torchtitan/transformer_impl.py
@@ -0,0 +1,736 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The concrete Engine implementation using PyTorch TorchTitan parallelism (FSDP2 + TP + PP)
+"""
+
+import gc
+import logging
+import os
+import re
+from contextlib import nullcontext
+from typing import Any, Callable, Optional
+
+import torch
+import torch.distributed
+from tensordict import TensorDict
+from torch.distributed.checkpoint.state_dict import get_model_state_dict
+from torch.distributed.tensor import DTensor
+from torchtitan.config.job_config import (
+ Checkpoint,
+ Compile,
+ JobConfig,
+ LRScheduler,
+ Model,
+ Optimizer,
+ Parallelism,
+ Training,
+)
+from torchtitan.distributed import utils as dist_utils
+from torchtitan.distributed.context_parallel import prepare_context_parallel_input
+from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.train import Trainer
+
+import verl.utils.torch_functional as verl_F
+from verl.trainer.config import CheckpointConfig
+from verl.utils import tensordict_utils as tu
+from verl.utils.dataset.dataset_utils import DatasetPadMode
+from verl.utils.debug import log_gpu_memory_usage
+from verl.utils.device import get_device_id, get_device_name
+from verl.utils.fsdp_utils import (
+ load_fsdp_model_to_gpu,
+ load_fsdp_optimizer,
+ offload_fsdp_model_to_cpu,
+ offload_fsdp_optimizer,
+)
+from verl.utils.model import extract_multi_modal_inputs
+from verl.utils.torch_functional import logprobs_from_logits
+from verl.workers.config import HFModelConfig, TorchtitanEngineConfig, TorchtitanOptimizerConfig
+from verl.workers.engine.torchtitan.utils import (
+ derive_torchtitan_name_and_flavor,
+ enable_fsdp_gradient_division,
+ get_attention_masks,
+)
+
+from ..base import BaseEngine, BaseEngineCtx, EngineRegistry
+from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+device_name = get_device_name()
+
+
+class TorchTitanEngine(BaseEngine):
+ """
+ Concrete Engine implementation using PyTorch TorchTitan parallelism.
+
+ Supports model sharding with FSDP2, tensor parallelism, activation/optimizer offloading,
+ LoRA, and sequence parallelism following the TorchTitan design.
+ """
+
+ def __init__(
+ self,
+ model_config: HFModelConfig,
+ engine_config: TorchtitanEngineConfig,
+ optimizer_config: TorchtitanOptimizerConfig,
+ checkpoint_config: CheckpointConfig,
+ ):
+ """
+ Initialize the TorchTitanEngine.
+
+ Sets up distributed device meshes for tensor and data parallelism, LoRA, and offload policies.
+
+ Args:
+ model_config: Configuration for HuggingFace model.
+ engine_config: Configuration for FSDP/TorchTitan engine (uses FSDP2).
+ optimizer_config: Configuration for optimizer.
+ checkpoint_config: Configuration for checkpointing.
+ """
+ super().__init__()
+
+ self.model_config = model_config
+ self.engine_config = engine_config
+ self.optimizer_config = optimizer_config
+ self.checkpoint_config = checkpoint_config
+
+ # Disable torchtitan's dataloader since verl has its own data loading
+ # Ideally torchtitan trainer init should not initialize dataloader
+ import torchtitan.protocols.train_spec as train_spec_module
+
+ original_get_train_spec = train_spec_module.get_train_spec
+
+ def _get_train_spec_without_dataloader(model_name):
+ train_spec = original_get_train_spec(model_name)
+ train_spec.build_dataloader_fn = None
+ return train_spec
+
+ train_spec_module.get_train_spec = _get_train_spec_without_dataloader
+
+ # Derive torchtitan model name and flavor from HF config
+ torchtitan_name, torchtitan_flavor = derive_torchtitan_name_and_flavor(self.model_config.hf_config)
+
+ # Get train_spec and directly override model_args before Trainer init
+ train_spec = train_spec_module.get_train_spec(torchtitan_name)
+ model_args = train_spec.model_args.get(torchtitan_flavor)
+ if model_args is not None:
+ if hasattr(model_args, "attn_type"):
+ model_args.attn_type = self.engine_config.attn_type
+
+ model = Model(
+ name=torchtitan_name,
+ flavor=torchtitan_flavor,
+ hf_assets_path=self.model_config.path,
+ )
+ optimizer = Optimizer(
+ name=self.optimizer_config.name,
+ lr=self.optimizer_config.lr,
+ eps=self.optimizer_config.eps,
+ beta1=self.optimizer_config.betas[0],
+ beta2=self.optimizer_config.betas[1],
+ weight_decay=self.optimizer_config.weight_decay,
+ )
+
+ total_steps = self.optimizer_config.total_training_steps
+ lr_warmup_steps = self.optimizer_config.lr_warmup_steps
+ if lr_warmup_steps is None or lr_warmup_steps <= 0:
+ lr_warmup_steps = int(self.optimizer_config.lr_warmup_steps_ratio * total_steps)
+
+ lr_scheduler = LRScheduler(
+ warmup_steps=lr_warmup_steps,
+ decay_type=self.optimizer_config.decay_type,
+ min_lr_factor=self.optimizer_config.min_lr_factor,
+ )
+ parallelism = Parallelism(
+ data_parallel_replicate_degree=self.engine_config.data_parallel_replicate_size,
+ data_parallel_shard_degree=self.engine_config.data_parallel_shard_size,
+ fsdp_reshard_after_forward=self.engine_config.reshard_after_forward,
+ tensor_parallel_degree=self.engine_config.tensor_parallel_size,
+ pipeline_parallel_degree=self.engine_config.pipeline_parallel_size,
+ context_parallel_degree=self.engine_config.context_parallel_size,
+ expert_parallel_degree=self.engine_config.expert_parallel_size,
+ expert_tensor_parallel_degree=self.engine_config.expert_tensor_parallel_size,
+ )
+ checkpoint = Checkpoint(
+ enable=True,
+ initial_load_in_hf=True,
+ initial_load_model_only=True,
+ initial_load_path=model_config.path,
+ )
+ compile = Compile(enable=self.engine_config.use_torch_compile)
+ training_kwargs = {}
+ if self.engine_config.max_seq_len is not None:
+ training_kwargs["seq_len"] = self.engine_config.max_seq_len
+ if self.engine_config.offload_policy or self.engine_config.forward_only:
+ training = Training(enable_cpu_offload=True, **training_kwargs)
+ else:
+ training = Training(**training_kwargs)
+
+ # Construct Torchtitan's JobConfig
+ self.config = JobConfig(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ parallelism=parallelism,
+ checkpoint=checkpoint,
+ compile=compile,
+ training=training,
+ )
+ self.trainer = Trainer(self.config)
+
+ self._init_device_mesh()
+
+ # Re-enable FSDP's gradient division for verl's loss scaling.
+ # TorchTitan disables gradient division by default (for global token normalization),
+ # but verl's loss function multiplies by dp_size to compensate for gradient averaging.
+ if self.engine_config.data_parallel_shard_size > 1:
+ dp_size = self.get_data_parallel_size()
+ for model_part in self.trainer.model_parts:
+ enable_fsdp_gradient_division(model_part, dp_size)
+
+ if self.engine_config.full_determinism:
+ enable_full_determinism(seed=self.engine_config.seed)
+
+ # set FSDP offload params
+ self._is_offload_param = self.engine_config.param_offload
+ self._is_offload_optimizer = self.engine_config.optimizer_offload
+
+ if self.engine_config.entropy_from_logits_with_chunking:
+ entropy_from_logits = verl_F.entropy_from_logits_with_chunking
+ else:
+ entropy_from_logits = verl_F.entropy_from_logits
+
+ self.compute_entropy_from_logits = (
+ torch.compile(entropy_from_logits, dynamic=True)
+ if self.engine_config.use_torch_compile
+ else entropy_from_logits
+ )
+
+ @property
+ def is_param_offload_enabled(self) -> bool:
+ return self._is_offload_param
+
+ @property
+ def is_optimizer_offload_enabled(self) -> bool:
+ return self._is_offload_optimizer
+
+ def is_mp_src_rank_with_outputs(self):
+ """
+        Whether the current rank is the first rank in the model parallel group that contains model outputs
+ """
+ is_collect = True
+ # TP: outputs are on TP rank 0
+ if self.parallel_dims.tp > 1:
+ tp_mesh = self.parallel_dims.get_optional_mesh("tp")
+ is_collect = is_collect and (tp_mesh.get_local_rank() == 0)
+ # PP: outputs are on the last PP rank
+ if self.parallel_dims.pp > 1:
+ pp_mesh = self.parallel_dims.get_optional_mesh("pp")
+ is_collect = is_collect and (pp_mesh.get_local_rank() == self.parallel_dims.pp - 1)
+ # CP: outputs are on CP rank 0
+ if self.parallel_dims.cp > 1:
+ cp_mesh = self.parallel_dims.get_optional_mesh("cp")
+ is_collect = is_collect and (cp_mesh.get_local_rank() == 0)
+ return is_collect
+
+ def initialize(self):
+ """
+ Build the model, optimizer, and learning rate scheduler with TorchTitan parallelism.
+
+ Applies device, dtype, and precision configurations, including mixed precision.
+ Sets up checkpoint manager.
+ """
+ self.module = self.trainer.model_parts
+ self.checkpointer = self.trainer.checkpointer
+ # load initial HF weights
+ self.checkpointer.load()
+
+ if not self.engine_config.forward_only:
+ self.optimizer = self.trainer.optimizers
+ self.lr_scheduler = self.trainer.lr_schedulers
+ else:
+ self.optimizer = None
+ self.lr_scheduler = None
+
+ self.to(
+ device="cpu",
+ model=self._is_offload_param,
+ optimizer=self._is_offload_optimizer,
+ grad=self._is_offload_param,
+ )
+
+ log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)
+
+ def _init_device_mesh(self):
+ """Initialize the device mesh for TorchTitan style parallelism."""
+ world_size = torch.distributed.get_world_size()
+ self.parallel_dims = ParallelDims(
+ dp_shard=self.engine_config.data_parallel_shard_size,
+ dp_replicate=self.engine_config.data_parallel_replicate_size,
+ cp=self.engine_config.context_parallel_size,
+ tp=self.engine_config.tensor_parallel_size,
+ pp=self.engine_config.pipeline_parallel_size,
+ ep=self.engine_config.expert_parallel_size,
+ etp=self.engine_config.expert_tensor_parallel_size,
+ world_size=world_size,
+ )
+ self.device_mesh = self.parallel_dims.build_mesh()
+
+ def train_mode(self, **kwargs):
+ """Return a context manager for training mode."""
+ return EngineTrainModeCtx(self, **kwargs)
+
+ def eval_mode(self, **kwargs):
+ """Return a context manager for evaluation mode."""
+ return EngineEvalModeCtx(self, **kwargs)
+
+ def get_data_parallel_rank(self):
+ mesh = self._get_data_parallel_mesh()
+ return 0 if mesh is None else mesh.get_local_rank()
+
+ def get_data_parallel_size(self):
+ return self.engine_config.data_parallel_shard_size * self.engine_config.data_parallel_replicate_size
+
+ def get_data_parallel_group(self):
+ mesh = self._get_data_parallel_mesh()
+ if mesh is not None:
+ return mesh.get_group()
+ # If world_size == dp_size (e.g. single GPU, or all ranks are DP),
+ # return WORLD so that collective ops in _postprocess_output
+ # (allgather_dict_into_dict, all_reduce) still run and produce the
+ # correct metric aggregation format.
+ if torch.distributed.get_world_size() == self.get_data_parallel_size():
+ return torch.distributed.group.WORLD
+ return None
+
+ def _get_data_parallel_mesh(self):
+ """Get the data parallel mesh, handling hybrid/fully/replicate shard modes."""
+ mesh = self.parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"])
+ if mesh is None:
+ mesh = self.parallel_dims.get_optional_mesh("fsdp")
+ if mesh is None:
+ mesh = self.parallel_dims.get_optional_mesh("dp_replicate")
+ return mesh
+
+ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False):
+ """Perform forward and optionally backward pass on a batch."""
+ tu.assign_non_tensor(data, sp_size=self.engine_config.tensor_parallel_size)
+
+ # Compute num_tokens in global batch for loss normalization
+ batch_num_tokens = data["loss_mask"].sum().to(get_device_id())
+ dp_group = self.get_data_parallel_group()
+ if dp_group is not None:
+ torch.distributed.all_reduce(batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=dp_group)
+ tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())
+ tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())
+
+ micro_batches, indices = prepare_micro_batches(
+ data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
+ )
+
+ output_lst = []
+
+ ctx = torch.no_grad() if forward_only else nullcontext()
+
+ for micro_batch in micro_batches:
+ with ctx:
+ loss, output = self.forward_step(micro_batch, loss_function=loss_function, forward_only=forward_only)
+ if not forward_only:
+ loss.backward()
+ output_lst.append(output)
+
+ return postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)
+
+ def model_forward_step(
+ self,
+ *,
+ inputs: torch.Tensor,
+ extra_inputs: dict[str, torch.Tensor] | None = None,
+ extra_kwargs: dict[str, torch.Tensor] | None = None,
+ ) -> torch.Tensor:
+ """
+ Perform a forward pass through the trainer model without backward.
+ """
+ model_parts = self.module
+ parallel_dims = self.parallel_dims
+
+ if parallel_dims.pp_enabled:
+ raise NotImplementedError(
+ "Pipeline parallelism is not yet supported in model_forward_step. "
+ "This will be implemented in a follow-up PR."
+ )
+ else:
+ # Non-PP forward
+ assert len(model_parts) == 1
+ with self.trainer.train_context():
+ with self.trainer.maybe_enable_amp:
+ pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
+
+ if isinstance(pred, DTensor):
+ pred = pred.full_tensor()
+ return pred
+
+ def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
+ raise NotImplementedError("forward_step must be implemented in subclass")
+
+ def optimizer_zero_grad(self):
+ """Zero gradients."""
+ self.optimizer.zero_grad()
+
+ def optimizer_step(self):
+ """Perform optimizer step with gradient clipping."""
+ grad_norm = dist_utils.clip_grad_norm_(
+ [p for m in self.module for p in m.parameters()],
+ self.config.training.max_norm,
+ foreach=True,
+ pp_mesh=self.parallel_dims.get_optional_mesh("pp"),
+ ep_enabled=self.parallel_dims.ep_enabled,
+ )
+
+ # if grad_norm is not finite, skip the update
+ if not torch.isfinite(grad_norm):
+ logger.warning(f"grad_norm is not finite: {grad_norm}")
+ self.optimizer.zero_grad()
+ else:
+ self.optimizer.step()
+ return grad_norm.item()
+
+ def lr_scheduler_step(self):
+ """Advance learning rate scheduler."""
+ self.lr_scheduler.step()
+ lr = self.lr_scheduler.schedulers[0].get_last_lr()[0]
+ return lr
+
+ def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
+ """Move model and/or optimizer to CPU or GPU."""
+ super().to(device=device, model=model, optimizer=optimizer, grad=grad)
+
+ if self.engine_config.forward_only:
+ return
+
+ device_name = get_device_name()
+ assert device in (device_name, "cpu")
+ if device == device_name:
+ if model:
+ for module in self.module:
+ load_fsdp_model_to_gpu(module)
+ if optimizer and self.optimizer is not None:
+ load_fsdp_optimizer(self.optimizer, device)
+ gc.collect()
+ elif device == "cpu":
+ if model:
+ for module in self.module:
+ offload_fsdp_model_to_cpu(module)
+ if optimizer and self.optimizer is not None:
+ offload_fsdp_optimizer(self.optimizer)
+ else:
+ raise ValueError(f"Invalid device type: {device}")
+
+ def save_checkpoint(
+ self,
+ local_path: str,
+ hdfs_path: Optional[str] = None,
+ global_step: int = 0,
+ max_ckpt_to_keep: Optional[int] = None,
+ **kwargs,
+ ) -> None:
+ """Save checkpoint."""
+ if self._is_offload_param:
+ for module in self.module:
+ load_fsdp_model_to_gpu(module)
+
+ # Override TorchTitan's folder to use verl's path
+ parent_dir = os.path.dirname(local_path)
+ self.checkpointer.folder = parent_dir
+
+ if max_ckpt_to_keep is not None:
+ self.checkpointer.keep_latest_k = max_ckpt_to_keep
+
+ self.checkpointer.save(curr_step=global_step)
+
+ torch.distributed.barrier()
+ if self._is_offload_param:
+ for module in self.module:
+ offload_fsdp_model_to_cpu(module)
+
+ def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+ ) -> None:
+ """Load checkpoint."""
+ if self._is_offload_param:
+ for module in self.module:
+ load_fsdp_model_to_gpu(module)
+
+ # Override TorchTitan's folder to use verl's path
+ parent_dir = os.path.dirname(local_path)
+ self.checkpointer.folder = parent_dir
+
+ # Extract step number from path (verl uses global_step_N format)
+ match = re.search(r"global_step_(\d+)", local_path)
+ if match:
+ step = int(match.group(1))
+ self.checkpointer.load(step=step)
+ else:
+ # Fallback to latest
+ self.checkpointer.load(step=-1)
+
+ torch.distributed.barrier()
+ if self._is_offload_param:
+ for module in self.module:
+ offload_fsdp_model_to_cpu(module)
+
+ if self._is_offload_optimizer:
+ offload_fsdp_optimizer(self.optimizer)
+
+ def get_per_tensor_param(self, **kwargs):
+ for module in self.module:
+ load_fsdp_model_to_gpu(module)
+
+ # Collect state dicts from all model parts
+ params = {}
+ for module in self.module:
+ module_params = get_model_state_dict(module)
+ params.update(module_params)
+
+ if self._is_offload_param:
+ for module in self.module:
+ offload_fsdp_model_to_cpu(module)
+
+ # Convert TorchTitan key names to HuggingFace key names (expected by vLLM)
+ sd_adapter = self.checkpointer.sd_adapter
+ if sd_adapter is not None:
+ params = sd_adapter.to_hf(params)
+
+ # When weight tying is enabled, the sd_adapter skips lm_head.weight during
+ # to_hf() conversion (since it's the same tensor as embed_tokens.weight in
+ # the torchtitan model). But vLLM needs lm_head.weight explicitly, so we
+ # add it back as a reference to embed_tokens.weight.
+ if "model.embed_tokens.weight" in params and "lm_head.weight" not in params:
+ params["lm_head.weight"] = params["model.embed_tokens.weight"]
+
+        device = get_device_id()  # used when FSDP2 sets cpu_offload_policy
+        # TODO: cast fp32 to bf16 to reduce weight sync overhead; needs more fine-grained control, e.g. MoE gate
+ per_tensor_param = (
+ (
+ name,
+ param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True)
+ if isinstance(param, DTensor)
+ else param,
+ )
+ for name, param in params.items()
+ )
+ # TODO: support Torchtitan PEFT
+ return per_tensor_param, None
+
+
+class EngineEvalModeCtx(BaseEngineCtx):
+ def __init__(self, engine: TorchTitanEngine, **kwargs):
+ super().__init__(engine=engine, mode="eval", **kwargs)
+
+ def __enter__(self):
+ assert isinstance(self.engine, TorchTitanEngine)
+ super().__enter__()
+ for module in self.engine.module:
+ module.eval()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ assert isinstance(self.engine, TorchTitanEngine)
+
+ # Reshard the root FSDP module
+ if self.engine.engine_config.data_parallel_shard_size > 1:
+ for module in self.engine.module:
+ module.reshard()
+
+ super().__exit__(exc_type, exc_value, traceback)
+
+
+class EngineTrainModeCtx(BaseEngineCtx):
+ def __init__(self, engine: TorchTitanEngine, **kwargs):
+ super().__init__(engine=engine, mode="train", **kwargs)
+
+ def __enter__(self):
+ assert isinstance(self.engine, TorchTitanEngine)
+ super().__enter__()
+ for module in self.engine.module:
+ module.train()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ assert isinstance(self.engine, TorchTitanEngine)
+ self.engine.optimizer_zero_grad()
+ super().__exit__(exc_type, exc_value, traceback)
+
+
+@EngineRegistry.register(model_type="language_model", backend=["torchtitan"], device=["cuda", "npu"])
+class TorchTitanEngineWithLMHead(TorchTitanEngine):
+ """TorchTitan engine implementation for language models with LM head."""
+
+ def prepare_model_inputs(self, micro_batch: TensorDict):
+ use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
+ pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
+ assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported"
+
+ multi_modal_inputs = extract_multi_modal_inputs(micro_batch.get("multi_modal_inputs", []))
+ input_ids = micro_batch["input_ids"]
+ position_ids = micro_batch["position_ids"]
+ output_args = {}
+
+ if use_remove_padding:
+ input_ids = input_ids.values().unsqueeze(0)
+ if position_ids.dim() == 3:
+ position_ids = position_ids.values().unsqueeze(1)
+ else:
+ position_ids = position_ids.values().unsqueeze(0)
+
+ labels = torch.roll(input_ids, shifts=-1, dims=1)
+ attn_type = self.trainer.model_args.attn_type
+ attention_mask = get_attention_masks(
+ input_batch=input_ids,
+ positions=position_ids,
+ attn_type=attn_type,
+ )
+ else:
+ loss_mask = micro_batch["loss_mask"]
+ pad_token_id = tu.get_non_tensor_data(data=micro_batch, key="pad_token_id", default=0)
+ batch_size = micro_batch.batch_size[0]
+ max_seq_len = max(input_ids.offsets().diff())
+
+ labels = torch.roll(input_ids.values(), shifts=-1, dims=0)
+ input_ids = torch.nested.to_padded_tensor(
+ input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)
+ )
+
+ if position_ids.dim() == 3:
+ position_ids = torch.nested.to_padded_tensor(
+ position_ids, padding=0, output_size=(batch_size, 4, max_seq_len)
+ ).transpose(0, 1)
+ else:
+ position_ids = torch.nested.to_padded_tensor(
+ position_ids, padding=0, output_size=(batch_size, max_seq_len)
+ )
+
+ attention_mask_list = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask]
+ attention_mask = torch.nested.as_nested_tensor(attention_mask_list, layout=torch.jagged)
+ attention_mask = torch.nested.to_padded_tensor(
+ attention_mask, padding=0, output_size=(batch_size, max_seq_len)
+ )
+
+ extra_inputs = {
+ "positions": position_ids,
+ }
+        # For arguments like attention_masks, we have to put them in a separate
+        # dict because extra_inputs are not forwarded to other stages in PP, but
+        # extra_kwargs are.
+ extra_kwargs: dict[str, Any] = {"attention_masks": attention_mask}
+ if self.parallel_dims.cp_enabled:
+ input_ids, labels, extra_kwargs = prepare_context_parallel_input(
+ input_ids,
+ labels,
+ extra_kwargs,
+ self.parallel_dims.get_mesh("cp"),
+ self.trainer.device,
+ self.trainer.job_config.parallelism.context_parallel_load_balancer,
+ )
+
+ # TODO(jessicazhong): multimodal is not yet supported for Torchtitan engine
+ extra_inputs.update(multi_modal_inputs)
+ output_args["labels"] = labels
+ return input_ids, extra_inputs, extra_kwargs, output_args
+
+ def prepare_model_outputs(self, logits, output_args, micro_batch: TensorDict):
+ use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
+ pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
+ assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported"
+
+ temperature = micro_batch["temperature"]
+ calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False)
+ labels = output_args["labels"]
+ model_output = {}
+
+ input_ids = micro_batch["input_ids"]
+ cu_seqlens = input_ids.offsets()
+ if use_remove_padding:
+ labels = labels.squeeze(0)
+ logits_rmpad = logits.squeeze(0)
+ # PyTorch's autograd doesn't allow in-place modification of views when gradients need to flow back
+ logits_rmpad = logits_rmpad / temperature
+
+ inplace_backward = True
+ if calculate_entropy:
+ inplace_backward = False
+ log_probs = logprobs_from_logits(
+ logits=logits_rmpad,
+ labels=labels,
+ inplace_backward=inplace_backward,
+ )
+
+ if calculate_entropy:
+ if not self.engine_config.entropy_checkpointing:
+ entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)
+ else:
+ entropy_rmpad = torch.utils.checkpoint.checkpoint(self.compute_entropy_from_logits, logits_rmpad)
+
+ log_probs = torch.nested.nested_tensor_from_jagged(log_probs.squeeze(0), cu_seqlens)
+ if calculate_entropy:
+ entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
+ else:
+ logits.div_(temperature)
+ if calculate_entropy:
+ if not self.engine_config.entropy_checkpointing:
+ entropy = verl_F.entropy_from_logits(logits)
+ else:
+ entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)
+
+ seq_lengths = cu_seqlens.diff()
+ starts = torch.zeros_like(seq_lengths, dtype=torch.int64)
+ logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged)
+ logits_rmpad = torch.cat([t for t in logits.unbind()])
+ log_probs = logprobs_from_logits(logits=logits_rmpad, labels=output_args["labels"])
+ log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
+ if calculate_entropy:
+ entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)
+ entropy_rmpad = torch.cat([t for t in entropy.unbind()])
+ entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
+
+ model_output["log_probs"] = log_probs
+ if calculate_entropy:
+ model_output["entropy"] = entropy
+
+ return model_output
+
+ def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
+ device_name = get_device_name()
+ micro_batch = micro_batch.to(get_device_id())
+ input_ids, extra_inputs, extra_kwargs, output_args = self.prepare_model_inputs(micro_batch=micro_batch)
+
+ with torch.autocast(device_type=device_name, dtype=torch.bfloat16):
+ logits = self.model_forward_step(inputs=input_ids, extra_inputs=extra_inputs, extra_kwargs=extra_kwargs)
+
+ model_output = self.prepare_model_outputs(logits=logits, output_args=output_args, micro_batch=micro_batch)
+
+ if loss_function is not None:
+ loss, metrics = loss_function(
+ model_output=model_output, data=micro_batch, dp_group=self.get_data_parallel_group()
+ )
+ else:
+ assert forward_only, "forward_only must be True when loss_function is None"
+ loss = torch.tensor(1.0, device=device_name)
+ metrics = {}
+
+ output = {
+ "model_output": model_output,
+ "loss": loss.detach().item(),
+ "metrics": metrics,
+ }
+
+ return loss, output
diff --git a/verl/workers/engine/torchtitan/utils.py b/verl/workers/engine/torchtitan/utils.py
new file mode 100644
index 00000000000..686fb94e6b2
--- /dev/null
+++ b/verl/workers/engine/torchtitan/utils.py
@@ -0,0 +1,213 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import torch
+import torch.nn as nn
+from torch.distributed._composable.fsdp import FSDPModule
+from torch.nn.attention.flex_attention import _mask_mod_signature, and_masks
+from torchtitan.models.attention import VarlenMetadata, create_attention_mask, get_causal_mask_mod
+from torchtitan.protocols.model import AttentionMasksType
+
+logger = logging.getLogger(__name__)
+
+# Mapping from HuggingFace model_type to torchtitan model name.
+# Torchtitan models not mapped here:
+# - flux: diffusion model, not applicable to verl's RL/SFT workflows
+# - llama3_ft: fault-tolerant variant of llama3, same HF models (mapped via "llama")
+_HF_MODEL_TYPE_TO_TORCHTITAN_NAME = {
+ "qwen2": "qwen3",
+ "qwen3": "qwen3",
+ "qwen2_moe": "qwen3",
+ "qwen3_moe": "qwen3",
+ "llama": "llama3",
+ "llama4": "llama4",
+ "deepseek_v3": "deepseek_v3",
+ "gpt_oss": "gpt_oss",
+}
+
+
+def derive_torchtitan_name_and_flavor(hf_config) -> tuple[str, str]:
+ """Derive torchtitan model name and flavor from a HuggingFace config.
+
+ The name is mapped from ``hf_config.model_type``. The flavor is found by
+ matching architecture parameters (dim, n_layers, vocab_size) against the
+ known flavors registered in the torchtitan TrainSpec.
+
+ Args:
+ hf_config: A HuggingFace AutoConfig object.
+
+ Returns:
+ A ``(name, flavor)`` tuple.
+
+ Raises:
+ ValueError: If model_type is unsupported or no matching flavor is found.
+ """
+ import torchtitan.protocols.train_spec as train_spec_module
+
+ model_type = getattr(hf_config, "model_type", None)
+ if model_type is None:
+ raise ValueError("HuggingFace config does not have 'model_type' field")
+
+ name = _HF_MODEL_TYPE_TO_TORCHTITAN_NAME.get(model_type)
+ if name is None:
+ raise ValueError(
+ f"Cannot derive torchtitan model name from HF model_type '{model_type}'. "
+ f"Supported types: {list(_HF_MODEL_TYPE_TO_TORCHTITAN_NAME.keys())}."
+ )
+
+ train_spec = train_spec_module.get_train_spec(name)
+
+ hidden_size = hf_config.hidden_size
+ num_layers = hf_config.num_hidden_layers
+ vocab_size = hf_config.vocab_size
+
+ for flavor_name, model_args in train_spec.model_args.items():
+ if (
+ getattr(model_args, "dim", None) == hidden_size
+ and getattr(model_args, "n_layers", None) == num_layers
+ and getattr(model_args, "vocab_size", None) == vocab_size
+ ):
+ logger.info(
+ f"Auto-derived torchtitan name='{name}', flavor='{flavor_name}' from HF model_type='{model_type}'"
+ )
+ return name, flavor_name
+
+ raise ValueError(
+ f"No matching torchtitan flavor found for model_type='{model_type}' "
+ f"(hidden_size={hidden_size}, num_hidden_layers={num_layers}, "
+ f"vocab_size={vocab_size}). "
+ f"Available flavors for '{name}': {list(train_spec.model_args.keys())}."
+ )
+
+
+def enable_fsdp_gradient_division(model: nn.Module, dp_size: int) -> None:
+ """
+ Re-enable FSDP's automatic gradient division.
+
+ TorchTitan calls disable_fsdp_gradient_division() which sets gradient_divide_factor=1.0.
+ This re-enables it by setting the factor to the specified dp_size, so gradients are
+ averaged across FSDP ranks. This is needed for verl's loss scaling (loss * dp_size)
+ to work correctly.
+
+ Args:
+ model: The model (or model part) to enable gradient division on.
+ dp_size: The data parallel size to use as the gradient divide factor.
+ """
+
+ for module in model.modules():
+ if isinstance(module, FSDPModule):
+ module.set_gradient_divide_factor(float(dp_size))
+
+
+def get_attention_masks(
+ input_batch: torch.Tensor,
+ positions: torch.Tensor,
+ attn_type: str,
+) -> AttentionMasksType:
+ match attn_type:
+ case "flex":
+ return _get_flex_attention_masks(
+ input_batch,
+ positions,
+ )
+ case "varlen":
+ return _create_varlen_metadata_for_document(
+ input_batch,
+ positions,
+ )
+ case _:
+ raise TypeError("Only varlen and flex attn masks are supported")
+
+
+def _get_document_mask_mod(positions: torch.Tensor) -> _mask_mod_signature:
+ # Detect boundaries from position resets
+ first_dummy_value = positions[:, :1] - 1
+ position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1)
+ sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq]
+
+ def document_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
+ return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]
+
+ return document_mask
+
+
+def _get_flex_attention_masks(
+ input_batch: torch.Tensor,
+ positions: torch.Tensor,
+) -> AttentionMasksType:
+ mask_mods = [get_causal_mask_mod()]
+ B = input_batch.shape[0]
+ mask_mods.append(_get_document_mask_mod(positions=positions))
+ return create_attention_mask(and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1])
+
+
+def _create_varlen_metadata_for_document(input_batch: torch.Tensor, positions: torch.Tensor) -> VarlenMetadata:
+ """
+ Creates cumulative sequence length indices needed for variable length attention
+
+ Args:
+ input_batch: Input token IDs with shape [batch, seq].
+ positions: Position IDs with shape [batch, seq]. Boundaries detected where
+ position diff != 1 (i.e., position resets).
+
+ Returns:
+ VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len
+ """
+ batch_size, seq_len = input_batch.shape
+ device = input_batch.device
+
+ # Detect boundaries from position resets (where diff != 1)
+ first_dummy_value = positions[:, :1] - 1
+ position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1)
+ # boundary_mask[b, i] is True if position i starts a new document
+ boundary_mask = position_diff != 1 # [batch, seq]
+ boundary_mask[:, 0] = True
+
+ cu_seqlens_list, all_seq_lengths = [], []
+ offset = 0
+
+ for b in range(batch_size):
+ # Find positions where new documents start
+ boundary_positions = boundary_mask[b].nonzero(as_tuple=True)[0].to(torch.int32)
+ sample_cu_seqlens = torch.cat(
+ [
+ boundary_positions,
+ torch.tensor([seq_len], dtype=torch.int32, device=device),
+ ]
+ )
+ sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)
+
+ seq_lengths = torch.diff(sample_cu_seqlens)
+ all_seq_lengths.append(seq_lengths)
+
+ cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset
+ cu_seqlens_list.append(cu_seqlens_adjusted)
+
+ offset += seq_len
+
+ packed_cu_seqlens = torch.cat(cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)])
+
+ max_seqlen = 0
+ if len(all_seq_lengths) > 0:
+ all_seq_lengths = torch.cat(all_seq_lengths)
+ # device to host sync but only done once per model forward
+ max_seqlen = all_seq_lengths.max().item()
+
+ return VarlenMetadata(
+ cu_seq_q=packed_cu_seqlens,
+ cu_seq_k=packed_cu_seqlens,
+ max_q=max_seqlen,
+ max_k=max_seqlen,
+ )
diff --git a/verl/workers/engine/utils.py b/verl/workers/engine/utils.py
index 0484b0d2a0c..b3261dafde3 100644
--- a/verl/workers/engine/utils.py
+++ b/verl/workers/engine/utils.py
@@ -70,7 +70,10 @@ def prepare_micro_batches(
use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
sp_size = tu.get_non_tensor_data(data=data, key="sp_size", default=1)
+ force_group_size = tu.get_non_tensor_data(data=data, key="force_group_size", default=1)
+
if use_dynamic_bsz:
+ assert force_group_size == 1, "force_group_size is not supported when use_dynamic_bsz is True"
assert "max_token_len_per_gpu" in data.keys(), "max_token_len_per_gpu must be set when use_dynamic_bsz is True"
max_token_len_per_gpu = data["max_token_len_per_gpu"]
max_token_len = max_token_len_per_gpu * sp_size
@@ -84,8 +87,12 @@ def prepare_micro_batches(
use_dynamic_bsz_balance=use_dynamic_bsz_balance,
)
else:
+ total_data_size = len(data)
micro_batch_size_per_gpu = data["micro_batch_size_per_gpu"]
- micro_batches = tu.chunk_tensordict(data, len(data) // micro_batch_size_per_gpu)
+ assert total_data_size % (force_group_size * micro_batch_size_per_gpu) == 0, (
+ "data size must be divisible by force_group_size * micro_batch_size_per_gpu"
+ )
+ micro_batches = tu.chunk_tensordict(data, total_data_size // (micro_batch_size_per_gpu * force_group_size))
batch_idx_list = None
return micro_batches, batch_idx_list
diff --git a/verl/workers/engine/veomni/transformer_impl.py b/verl/workers/engine/veomni/transformer_impl.py
index 5aaf41f9669..6d0e50806dc 100644
--- a/verl/workers/engine/veomni/transformer_impl.py
+++ b/verl/workers/engine/veomni/transformer_impl.py
@@ -35,8 +35,11 @@
from verl.utils.fsdp_utils import fsdp_version
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import log_gpu_memory_usage
+from verl.utils.ulysses import (
+ get_ulysses_sequence_parallel_group,
+ set_ulysses_sequence_parallel_group,
+)
from verl.workers.config import HFModelConfig, VeOmniEngineConfig, VeOmniOptimizerConfig
-from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from ..base import BaseEngineCtx, EngineRegistry
from ..fsdp.transformer_impl import FSDPEngine, FSDPEngineWithLMHead
@@ -79,14 +82,27 @@ def __init__(
self.data_parallel_mode = "fsdp2"
self.rank = dist.get_rank()
+ fsdp_size = self.engine_config.fsdp_size
+ world_size = dist.get_world_size()
+ dp_size = world_size // self.engine_config.ulysses_parallel_size
+
+ if fsdp_size < 0 or fsdp_size >= dp_size:
+ data_parallel_replicate_size = 1
+ data_parallel_shard_size = dp_size
+ else:
+ if dp_size % fsdp_size != 0:
+ raise ValueError(
+ f"Data parallel size ({dp_size}) must be divisible by fsdp_size ({fsdp_size}). "
+ "Please adjust your parallel configuration."
+ )
+ data_parallel_replicate_size = dp_size // fsdp_size
+ data_parallel_shard_size = fsdp_size
+
parallel_state.init_parallel_state(
- dp_size=self.engine_config.data_parallel_size,
- dp_replicate_size=self.engine_config.data_parallel_replicate_size,
- dp_shard_size=self.engine_config.data_parallel_shard_size,
- tp_size=self.engine_config.tensor_parallel_size,
+ dp_size=dp_size,
+ dp_replicate_size=data_parallel_replicate_size,
+ dp_shard_size=data_parallel_shard_size,
ep_size=self.engine_config.expert_parallel_size,
- pp_size=self.engine_config.pipeline_parallel_size,
- cp_size=self.engine_config.context_parallel_size,
ulysses_size=self.engine_config.ulysses_parallel_size,
dp_mode=self.data_parallel_mode,
)
@@ -104,9 +120,9 @@ def __init__(
self.ulysses_sequence_parallel_size = self.engine_config.ulysses_parallel_size
if self.use_ulysses_sp:
- self.ulysses_sharding_manager = FSDPUlyssesShardingManager(parallel_state.get_parallel_state().device_mesh)
+ self.ulysses_parallel_group = parallel_state.get_parallel_state().device_mesh["sp"].get_group()
else:
- self.ulysses_sharding_manager = FSDPUlyssesShardingManager(None)
+ self.ulysses_parallel_group = None
if self.engine_config.entropy_from_logits_with_chunking:
entropy_from_logits = verl_F.entropy_from_logits_with_chunking
@@ -438,12 +454,13 @@ def __init__(self, engine: VeOmniEngine, **kwargs):
def __enter__(self):
assert isinstance(self.engine, VeOmniEngine)
super().__enter__()
- self.engine.ulysses_sharding_manager.__enter__()
+ self.prev_sp_group = get_ulysses_sequence_parallel_group()
+ set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group)
self.engine.module.train()
def __exit__(self, exc_type, exc_value, traceback):
assert isinstance(self.engine, VeOmniEngine)
- self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
+ set_ulysses_sequence_parallel_group(self.prev_sp_group)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
@@ -463,14 +480,15 @@ def __init__(self, engine: VeOmniEngine, **kwargs):
def __enter__(self):
assert isinstance(self.engine, VeOmniEngine)
super().__enter__()
- self.engine.ulysses_sharding_manager.__enter__()
+ self.prev_sp_group = get_ulysses_sequence_parallel_group()
+ set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group)
+ # TODO: Switch to eval mode after integrating the CI environment
# VeOmni (ref: https://github.com/ByteDance-Seed/VeOmni/pull/421)
self.engine.module.train()
def __exit__(self, exc_type, exc_value, traceback):
assert isinstance(self.engine, VeOmniEngine)
- self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
+ set_ulysses_sequence_parallel_group(self.prev_sp_group)
self.engine.optimizer_zero_grad()
super().__exit__(exc_type, exc_value, traceback)
@@ -492,6 +510,19 @@ class OmniSequenceShardCollator:
metadata={"help": "features to slice sequence dimension."},
)
+ # features to padding sequence dimension
+ padding_features: dict[str, int] = field(
+ default_factory=lambda: {
+ "pixel_values": 0,
+ },
+ metadata={"help": "features to padding sequence dimension."},
+ )
+
+ # padding scale for padding features
+ padding_scale: dict[str, int] = field(
+ default_factory=lambda: {"pixel_values": 4}, metadata={"help": "padding scale for padding features."}
+ )
+
def __post_init__(self):
self.sp_size = parallel_state.get_parallel_state().sp_size
self.sp_rank = parallel_state.get_parallel_state().sp_rank
@@ -501,7 +532,35 @@ def sp_slice(self, feature: torch.Tensor, dim: int = -1) -> dict[str, "torch.Ten
sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
return feature.narrow(dim, self.sp_rank * sp_chunk_size, sp_chunk_size)
+ def sp_padding(
+ self, tensor: "torch.Tensor", dim: int = -1, pad_value: int = 0, pad_scale: int = 1
+ ) -> "torch.Tensor":
+ """
+ Pads a tensor along ``dim`` so its length aligns with the sp size (times pad_scale).
+ """
+ seq_length = tensor.size(dim)
+ scale_sp_size = self.sp_size * pad_scale
+
+ sp_chunk_size = (seq_length + scale_sp_size - 1) // scale_sp_size
+ pad_size = sp_chunk_size * scale_sp_size - seq_length
+ if pad_size == 0:
+ return tensor
+
+ pad_shape = list(tensor.shape)
+ pad_shape[dim] = pad_size
+ pad = torch.full(pad_shape, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
+ return torch.cat((tensor, pad), dim=dim)
+
def __call__(self, batch: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
+ for key in batch.keys():
+ if key in self.padding_features.keys():
+ batch[key] = self.sp_padding(
+ batch[key],
+ dim=self.sp_slice_features.get(key, -1),
+ pad_value=self.padding_features[key],
+ pad_scale=self.padding_scale.get(key, 1),
+ )
+
# sp slice
for key in batch.keys():
if key in self.sp_slice_features.keys():
diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py
index 6f8029600ea..3d129d1fe2f 100644
--- a/verl/workers/engine_workers.py
+++ b/verl/workers/engine_workers.py
@@ -179,17 +179,20 @@ def _postprocess_output(self, output, *, global_token_num, delta_time, forward_o
# Here each metric in metrics can be a list (micro-batch metrics) or a singleton
# we should always sum the loss of each micro-batch as we scale by global_bsz/global_token
loss = torch.sum(torch.tensor(output.pop("loss"), device=self.device_name))
- torch.distributed.all_reduce(
- loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()
- )
+ dp_group = self.engine.get_data_parallel_group()
+ if dp_group is not None:
+ torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_group)
loss = loss.item()
# For grad_norm, we do not perform all reduce because it is already been done when clipping grad
grad_norm = metrics.pop("grad_norm", None)
lr = metrics.pop("lr", None)
- # For other metrics, we perform all gather in dp group
- final_metrics = allgather_dict_into_dict(data=metrics, group=self.engine.get_data_parallel_group())
+ # For other metrics, we perform all gather in dp group (only if DP > 1)
+ if dp_group is not None:
+ final_metrics = allgather_dict_into_dict(data=metrics, group=dp_group)
+ else:
+ final_metrics = metrics
final_metrics["loss"] = loss
if grad_norm is not None:
final_metrics["grad_norm"] = grad_norm
@@ -580,6 +583,9 @@ def init_model(self):
backend, is_master=(torch.distributed.get_rank() == 0), bucket_size=bucket_size, **engine_kwargs
)
+ # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo
+ aggressive_empty_cache(force_sync=True)
+
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref"))
@DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
@_with_routing_replay_flag(enabled=False)
@@ -621,11 +627,10 @@ async def update_weights(self):
- after update_weights: rollout should be in wake_up mode.
2. For async training with disaggregated trainer and rollout, send_weights only by checkpoint engine.
"""
- assert self.checkpoint_engine is not None
# 0. send_weights only for async training with disaggregated trainer and rollout
if self.config.rollout.checkpoint_engine.backend != "naive":
- per_tensor_param, _ = self.engine.get_per_tensor_param()
+ per_tensor_param, _ = self.actor.engine.get_per_tensor_param()
await self.checkpoint_engine.send_weights(per_tensor_param)
return
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 1425091432e..24b46e13c1b 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -28,6 +28,7 @@
import torch.distributed as dist
from codetiming import Timer
from omegaconf import DictConfig, OmegaConf, open_dict
+from omegaconf.errors import ConfigAttributeError
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import save_file
from torch.distributed.device_mesh import init_device_mesh
@@ -81,6 +82,9 @@
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
from verl.utils.py_functional import convert_to_regular_types
+
+# QAT support
+from verl.utils.qat import apply_qat, enable_qat_fuse
from verl.utils.ray_utils import get_event_loop
from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.config.optimizer import build_optimizer
@@ -275,6 +279,52 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
+ def _init_qat_config(self):
+ """Initialize QAT configuration from actor.qat."""
+ try:
+ self.qat_config = self.config.actor.qat
+ self._qat_enabled = self.qat_config.enable
+ if self._qat_enabled:
+ logger.info(
+ f"QAT enabled: mode={self.qat_config.mode}, config_path={self.qat_config.quantization_config_path}"
+ )
+ except (AttributeError, KeyError, ConfigAttributeError):
+ # QAT config not provided, disable QAT
+ self._qat_enabled = False
+ self.qat_config = None
+
+ def _restore_w4a4_input_scales(self, model, model_path):
+ """Restore input_global_scale and input_amax from checkpoint for W4A4 mode."""
+ import glob
+
+ from safetensors import safe_open
+
+ safetensor_files = glob.glob(f"{model_path}/model*.safetensors")
+ loaded_count = 0
+
+ for sf_path in safetensor_files:
+ with safe_open(sf_path, framework="pt") as f:
+ for key in f.keys():
+ if "input_global_scale" in key:
+ module_path = key.replace(".input_global_scale", "")
+ amax_key = f"{module_path}.input_amax"
+
+ module = model
+ for part in module_path.split("."):
+ module = getattr(module, part)
+
+ scale_val = f.get_tensor(key)
+ val = scale_val.item() if scale_val.numel() == 1 else scale_val.max().item()
+ module.input_global_scale.fill_(val)
+
+ amax_val = f.get_tensor(amax_key)
+ amax = amax_val.item() if amax_val.numel() == 1 else amax_val.max().item()
+ module.input_amax.fill_(amax)
+ loaded_count += 1
+
+ if self.rank == 0:
+ logger.info(f"[W4A4] Loaded {loaded_count} input scales from checkpoint")
+
def _build_model_optimizer(
self,
model_path,
@@ -485,6 +535,13 @@ def _build_model_optimizer(
if self.rank == 0:
print("[actor model] No vision tower found.")
+ # Apply QAT before FSDP wrapping (actor only)
+ if role == "actor" and self._qat_enabled:
+ actor_module = apply_qat(actor_module, self.qat_config)
+ enable_qat_fuse(actor_module)
+ if self.qat_config.mode == "w4a4":
+ self._restore_w4a4_input_scales(actor_module, self.config.model.path)
+
torch.distributed.barrier()
if self.rank == 0:
@@ -505,6 +562,9 @@ def _build_model_optimizer(
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
+ # Store param_dtype for QAT quantizer
+ self._param_dtype = param_dtype
+
auto_wrap_policy = get_fsdp_wrap_policy(
module=actor_module,
config=fsdp_config.get("wrap_policy", None),
@@ -740,6 +800,23 @@ async def rollout_mode(self):
for name, param in params.items()
)
+ # QAT: quantize weights before sending to vLLM
+ if self._qat_enabled:
+ from verl.utils.qat.quantizer import QATQuantizer
+
+ quantizer = QATQuantizer(
+ mode=self.qat_config.mode,
+ group_size=self.qat_config.group_size,
+ ignore_patterns=self.qat_config.ignore_patterns,
+ device=torch.device(get_device_id()),
+ param_dtype=self._param_dtype,
+ )
+ per_tensor_param = quantizer.quantize_with_fusion(
+ per_tensor_param,
+ target_device=torch.device("cpu"),
+ )
+ aggressive_empty_cache(force_sync=True)
+
if self.config.rollout.free_cache_engine:
await self.rollout.resume(tags=["weights"])
log_gpu_memory_usage("After resume weights", logger=logger)
@@ -774,6 +851,9 @@ def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
+ # Initialize QAT config before _build_model_optimizer
+ self._init_qat_config()
+
override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_shm = self.config.model.get("use_shm", False)
@@ -901,6 +981,9 @@ def init_model(self):
checkpoint_config=checkpoint_contents,
)
+ # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo
+ aggressive_empty_cache(force_sync=True)
+
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 3baf4020810..f78dcde56e2 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -155,7 +155,6 @@ def _init_hf_config_and_tf_config(
if enable_mtp:
assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer"
assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True"
- assert megatron_config.vanilla_mbridge, "MTP requires vanilla_mbridge to be True"
override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor
else:
if hasattr(hf_config, "num_nextn_predict_layers"):
@@ -199,6 +198,10 @@ def _init_hf_config_and_tf_config(
# In case of invalid overrides, we need to make sure some critical params are set correctly
provider.params_dtype = dtype
+ # Ensure dtype settings propagate to Megatron-Bridge/TE
+ provider.fp16 = fp16
+ provider.bf16 = bf16
+
# Pass distributed info
provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size
provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size
@@ -670,7 +673,8 @@ def init_model(self):
if not self.config.actor.megatron.use_mbridge:
self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
- get_torch_device().empty_cache()
+ # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo
+ aggressive_empty_cache(force_sync=True)
log_gpu_memory_usage("After init_model finish", logger=logger)
async def rollout_mode(self):
diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py
index 31d5b9736b7..c8038606f1f 100644
--- a/verl/workers/rollout/base.py
+++ b/verl/workers/rollout/base.py
@@ -34,6 +34,8 @@ def __init__(
config: RolloutConfig,
model_config: HFModelConfig,
device_mesh: DeviceMesh,
+ *args,
+ **kwargs,
):
self.config = omega_conf_to_dataclass(config)
self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py
index bf83ac7d05f..ed327c25293 100644
--- a/verl/workers/rollout/replica.py
+++ b/verl/workers/rollout/replica.py
@@ -18,6 +18,7 @@
from enum import Enum
from typing import Any, Callable, Optional
+import ray
from omegaconf import DictConfig
from pydantic import BaseModel
from ray.actor import ActorHandle
@@ -30,6 +31,11 @@
logger = logging.getLogger(__file__)
+# Max number of concurrent calls to the methods of Rollout,
+# excluding calls to the generate method.
+CONTROL_METHOD_CONCURRENCY = 16
+
+
class TokenOutput(BaseModel):
token_ids: list[int]
"""response token ids"""
@@ -91,7 +97,7 @@ def __init__(
is_reward_model: bool = False,
) -> None:
self.replica_rank = replica_rank
- self.config = omega_conf_to_dataclass(config)
+ self.config: RolloutConfig = omega_conf_to_dataclass(config)
self.model_config: HFModelConfig = model_config
self.world_size = (
@@ -200,10 +206,18 @@ async def init_standalone(self):
self.workers = worker_group.workers
await self.launch_servers()
- @abstractmethod
def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
"""Get rollout worker actor class for colocated and standalone mode."""
- raise NotImplementedError
+ from verl.checkpoint_engine.base import CheckpointEngineWorker
+
+ rollout_worker_actor_cls = ray.remote(CheckpointEngineWorker)
+
+ return RayClassWithInitArgs(
+ cls=rollout_worker_actor_cls,
+ rollout_config=self.config,
+ model_config=self.model_config,
+ replica_rank=self.replica_rank,
+ )
@abstractmethod
async def launch_servers(self):
@@ -220,6 +234,12 @@ def server_handle(self) -> ActorHandle:
"""Get rollout server handle for Token-in-token-out generation."""
return self._server_handle
+ @property
+ def max_concurrency(self) -> int:
+ # 1000 is Ray's default max_concurrency for async execution.
+ # Add some margin to account for control method calls.
+ return max(1000, self.config.max_num_seqs + CONTROL_METHOD_CONCURRENCY)
+
def rollout_worker_use_gpu(self) -> bool:
return True
diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
index ab8ee461dea..964fe97df1f 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -39,15 +39,14 @@
)
from sglang.srt.managers.tokenizer_manager import ServerStatus
-from verl.single_controller.ray import RayClassWithInitArgs
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_visible_devices_keyword
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address
from verl.utils.profiler import DistProfiler, build_sglang_profiler_args
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
-from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter, _set_envs_and_config
-from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn
+from verl.workers.rollout.sglang_rollout.sglang_rollout import _set_envs_and_config
+from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
@@ -123,17 +122,19 @@ def __init__(
profiler_config = None
self.profiler_controller = DistProfiler(self.replica_rank, config=profiler_config, tool_config=tool_config)
- # used for NCCL process group
- if self.node_rank == 0:
+ # For multi-node, we need dist_init_addr so nodes can coordinate NCCL init.
+ # For single-node, let SGLang handle port selection internally via nccl_port,
+ # which also avoids port conflicts.
+ self._master_address = None
+ self._master_port = None
+ self._master_sock = None
+ if self.nnodes > 1 and self.node_rank == 0:
self._master_address = self._server_address
- self._master_port, self._master_sock = get_free_port(self._server_address)
+ self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True)
logger.info(
f"SGLangHttpServer, replica_rank: {self.replica_rank}, "
f"master address: {self._master_address}, port: {self._master_port}"
)
- else:
- self._master_address = None
- self._master_port = None
def get_master_address(self):
"""Get master address and port for init NCCL process group."""
@@ -145,10 +146,13 @@ def get_server_address(self):
return self._server_address, self._server_port
async def launch_server(self, master_address: str = None, master_port: int = None):
- if self.node_rank != 0:
- assert master_address and master_port, "non-master node should provide master address and port"
- self._master_address = master_address
- self._master_port = master_port
+ if self.nnodes > 1:
+ if self.node_rank != 0:
+ assert master_address and master_port, "non-master node should provide master address and port"
+ self._master_address = master_address
+ self._master_port = master_port
+ else:
+ self._master_sock.close()
engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {}
attention_backend = engine_kwargs.pop("attention_backend", None)
@@ -167,11 +171,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
else:
raise ValueError(f"Currently only support fp8 quantization, got: {quantization}")
- dist_init_addr = (
- f"[{self._master_address}]:{self._master_port}"
- if is_valid_ipv6_address(self._master_address)
- else f"{self._master_address}:{self._master_port}"
- )
infer_tp = self.config.tensor_model_parallel_size * self.config.data_parallel_size
args = {
"model_path": self.model_config.local_path,
@@ -186,7 +185,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"ep_size": self.config.expert_parallel_size,
"node_rank": self.node_rank,
"load_format": self.config.load_format,
- "dist_init_addr": dist_init_addr,
"nnodes": self.nnodes,
"trust_remote_code": self.model_config.trust_remote_code,
"max_running_requests": self.config.get("max_num_seqs", None),
@@ -202,6 +200,16 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
**engine_kwargs,
}
+ # Only set dist_init_addr for multi-node; for single-node, let SGLang
+ # handle port selection internally via nccl_port to avoid conflicts.
+ if self.nnodes > 1:
+ dist_init_addr = (
+ f"[{self._master_address}]:{self._master_port}"
+ if is_valid_ipv6_address(self._master_address)
+ else f"{self._master_address}:{self._master_port}"
+ )
+ args["dist_init_addr"] = dist_init_addr
+
if self.config.prometheus.enable:
if self.config.prometheus.served_model_name:
# Extract model name from path if it's a full path
@@ -278,7 +286,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
add_prometheus_middleware(app)
- self._server_port, self._server_task = await run_unvicorn(app, server_args, self._server_address)
+ self._server_port, self._server_task = await run_uvicorn(app, server_args, self._server_address)
self.tokenizer_manager.server_status = ServerStatus.Up
async def wake_up(self):
@@ -413,9 +421,6 @@ async def stop_profile(self):
await self.tokenizer_manager.stop_profile()
-_rollout_worker_actor_cls = ray.remote(ServerAdapter)
-
-
class SGLangReplica(RolloutReplica):
def __init__(
self,
@@ -428,16 +433,6 @@ def __init__(
super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)
self.server_class = ray.remote(SGLangHttpServer)
- def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
- """Get rollout worker actor class for colocated and standalone mode."""
- worker_dict_cls = RayClassWithInitArgs(
- cls=_rollout_worker_actor_cls,
- config=self.config,
- model_config=self.model_config,
- device_mesh=None,
- )
- return worker_dict_cls
-
async def launch_servers(self):
"""Launch http server in each node."""
assert len(self.workers) == self.world_size, (
@@ -496,6 +491,7 @@ async def launch_servers(self):
),
runtime_env={"env_vars": {f"RAY_EXPERIMENTAL_NOSET_{visible_devices_keyword}": "1"}},
name=name,
+ max_concurrency=self.max_concurrency,
).remote(
config=self.config,
model_config=self.model_config,
@@ -510,7 +506,9 @@ async def launch_servers(self):
self.servers.append(server)
# launch http server in each node
- master_address, master_port = await self.servers[0].get_master_address.remote()
+ master_address, master_port = None, None
+ if self.nnodes > 1:
+ master_address, master_port = await self.servers[0].get_master_address.remote()
await asyncio.gather(
*[
server.launch_server.remote(master_address=master_address, master_port=master_port)
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 2be15fc5b05..3048ab60148 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -98,6 +98,7 @@ def __init__(
config: RolloutConfig,
model_config: HFModelConfig,
device_mesh: DeviceMesh,
+ replica_rank: int = -1,
):
if config.get("quantization", None) == "fp8":
import sglang
@@ -120,7 +121,10 @@ def __init__(
rank = int(os.environ["RANK"])
local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
- self.replica_rank = rank // rollout_world_size
+ if replica_rank == -1:
+ self.replica_rank = rank // rollout_world_size
+ else:
+ self.replica_rank = replica_rank
self.rollout_rank = rank % rollout_world_size
self.node_rank = self.rollout_rank // local_world_size
self.local_rank = self.rollout_rank % local_world_size
diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
index 97e255ad68c..195e798a6cb 100644
--- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
+++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
@@ -23,13 +23,13 @@
from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup
-from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool
+from verl.single_controller.ray import SubRayResourcePool
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.net_utils import is_valid_ipv6_address
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter
-from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn
+from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
@@ -110,7 +110,7 @@ def get_server_address(self):
async def launch_server(self):
from tensorrt_llm import AsyncLLM
- from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig
+ from tensorrt_llm.llmapi import CapacitySchedulerPolicy, CudaGraphConfig, KvCacheConfig, SchedulerConfig
from tensorrt_llm.serve import OpenAIServer
assert self.config.pipeline_model_parallel_size == 1, "pipeline_model_parallel_size > 1 is not supported yet"
@@ -162,7 +162,10 @@ async def launch_server(self):
enable_padding=True,
batch_sizes=self.config.cudagraph_capture_sizes,
max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
- )
+ ),
+ "scheduler_config": SchedulerConfig(
+ capacity_scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION,
+ ),
}
)
@@ -202,7 +205,7 @@ async def launch_server(self):
)
app = trtllm_server.app
- self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address)
+ self._server_port, self._server_task = await run_uvicorn(app, None, self._server_address)
@resume_on_abort
async def generate(
@@ -280,9 +283,6 @@ async def report_device_ids(self) -> list[str]:
)
-_rollout_worker_actor_cls = ray.remote(ServerAdapter)
-
-
class TRTLLMReplica(RolloutReplica):
def __init__(
self,
@@ -295,17 +295,6 @@ def __init__(
super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)
self.node_ip = ray.util.get_node_ip_address().strip("[]")
- def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
- """Get rollout worker actor class for colocated and standalone mode."""
- worker_dict_cls = RayClassWithInitArgs(
- cls=_rollout_worker_actor_cls,
- config=self.config,
- model_config=self.model_config,
- device_mesh=None,
- replica_rank=self.replica_rank,
- )
- return worker_dict_cls
-
def rollout_worker_use_gpu(self) -> bool:
return False
@@ -392,6 +381,7 @@ async def launch_servers(self):
),
runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
name=name,
+ max_concurrency=self.max_concurrency,
).remote(
config=self.config,
model_config=self.model_config,
diff --git a/verl/workers/rollout/utils.py b/verl/workers/rollout/utils.py
index 246ed3896b1..69c688dfa24 100644
--- a/verl/workers/rollout/utils.py
+++ b/verl/workers/rollout/utils.py
@@ -13,13 +13,10 @@
# limitations under the License.
import asyncio
import logging
-import os
import uvicorn
from fastapi import FastAPI
-from verl.utils.net_utils import get_free_port
-
logger = logging.getLogger(__file__)
@@ -35,25 +32,42 @@ def get_max_position_embeddings(hf_config) -> int:
return int(max_len)
-async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5) -> tuple[int, asyncio.Task]:
- server_port, server_task = None, None
+class _UvicornServerAutoPort(uvicorn.Server):
+ """Uvicorn Server that reports the system-assigned port when port=0."""
+
+ def __init__(self, config: uvicorn.Config) -> None:
+ super().__init__(config)
+ self.actual_port: int | None = None
+ self._startup_done: asyncio.Event = asyncio.Event()
- for i in range(max_retries):
+ async def startup(self, sockets=None) -> None:
try:
- server_port, sock = get_free_port(server_address)
- app.server_args = server_args
- config = uvicorn.Config(app, host=server_address, port=server_port, log_level="warning")
- server = uvicorn.Server(config)
- server.should_exit = True
- await server.serve()
- server_task = asyncio.create_task(server.main_loop())
- break
- except (OSError, SystemExit) as e:
- logger.error(f"Failed to start HTTP server on port {server_port} at try {i}, error: {e}")
- else:
- logger.error(f"Failed to start HTTP server after {max_retries} retries, exiting...")
- os._exit(-1)
+ await super().startup(sockets=sockets)
+ if self.servers and self.config.port == 0:
+ sock = self.servers[0].sockets[0]
+ self.actual_port = sock.getsockname()[1]
+ else:
+ self.actual_port = self.config.port
+ finally:
+ self._startup_done.set()
+
+ async def get_port(self) -> int | None:
+ await self._startup_done.wait()
+ return self.actual_port
+
+
+async def run_uvicorn(app: FastAPI, server_args, server_address) -> tuple[int, asyncio.Task]:
+ app.server_args = server_args
+ config = uvicorn.Config(app, host=server_address, port=0, log_level="warning")
+ server = _UvicornServerAutoPort(config)
+ server_task = asyncio.create_task(server.serve())
+ server_port = await server.get_port()
+ if server_port is None:
+ # server.startup() failed; await the task to re-raise the exception from server.serve().
+ await server_task
+ # Fail loudly on an unexpected situation.
+ raise RuntimeError("Unexpected: HTTP server started without reporting listened port")
logger.info(f"HTTP server started on port {server_port}")
return server_port, server_task
diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py
index d2369dc2e89..7fa3b1dd67c 100644
--- a/verl/workers/rollout/vllm_rollout/utils.py
+++ b/verl/workers/rollout/vllm_rollout/utils.py
@@ -163,6 +163,15 @@ def __new__(cls, **kwargs):
# 2. patch online fp8 quant
if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1":
apply_vllm_fp8_patches()
+ # 3. patch QAT (compressed-tensors NVFP4) for dynamic weight loading
+ vllm_config = kwargs.get("vllm_config")
+ quant_config = getattr(vllm_config, "quant_config", None) if vllm_config else None
+ _is_qat_model = getattr(quant_config, "quant_format", None) == "nvfp4-pack-quantized"
+ if _is_qat_model:
+ from verl.utils.qat import apply_qat_patches
+
+ apply_qat_patches()
+ logger.info("Applied QAT patches in vLLM worker subprocess")
# TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0,
# please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action.
@@ -172,7 +181,9 @@ def __new__(cls, **kwargs):
if k not in os.environ:
os.environ[k] = VLLM_ASCEND_REQUIRED_ENV_VARS[k]
- return super().__new__(cls)
+ instance = super().__new__(cls)
+ instance._is_qat_model = _is_qat_model
+ return instance
def monkey_patch_model(self, vocab_size: int):
# patch compute_logits to avoid sampling OOV token
@@ -214,8 +225,14 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
self.model_runner.vllm_config
)
- # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs.
- if use_standard_weight_load:
+ if self._is_qat_model:
+ # QAT: Prepare for weight loading BEFORE receiving any buckets
+ from verl.utils.qat import prepare_qat_for_load_weights
+
+ prepare_qat_for_load_weights(self.model_runner.model, device=self.device)
+ logger.info("QAT: prepare_qat_for_load_weights completed")
+ elif use_standard_weight_load:
+ # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs.
patch_vllm_moe_model_weight_loader(self.model_runner.model)
# receive bucket and update weights
@@ -241,7 +258,13 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
if metadata["is_last"]:
break
- if use_standard_weight_load:
+ if self._is_qat_model:
+ # QAT: call process_weights_after_loading AFTER all buckets are received
+ from verl.utils.qat import manual_process_weights_after_loading
+
+ manual_process_weights_after_loading(self.model_runner.model)
+ logger.info("QAT: process_weights_after_loading completed")
+ elif use_standard_weight_load:
# Some post-load transforms are non-idempotent; run once after all buckets.
from vllm.model_executor.model_loader.utils import process_weights_after_loading
@@ -252,6 +275,7 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
# clean up
socket.close()
del buffer
+ gc.collect()
if shm is not None:
shm.close()
del shm
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index cf5ab342888..9c2bbad47cf 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -35,7 +35,6 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
-from verl.single_controller.ray import RayClassWithInitArgs
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_resource_name, get_visible_devices_keyword
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address
@@ -43,8 +42,7 @@
from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
-from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn
-from verl.workers.rollout.vllm_rollout import ServerAdapter
+from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn
from verl.workers.rollout.vllm_rollout.utils import (
VLLM_LORA_INT_ID,
VLLM_LORA_NAME,
@@ -153,10 +151,10 @@ def __init__(
if self.node_rank == 0:
self._master_address = self._server_address
# used for torch.distributed.init_process_group
- self._master_port, self._master_sock = get_free_port(self._server_address)
+ self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True)
# used for data parallel: --data-parallel-address, --data-parallel-rpc-port
- self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address)
- self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address)
+ self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address, with_alive_sock=True)
+ self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address, with_alive_sock=True)
else:
self._master_address = None
self._master_port = None
@@ -182,6 +180,12 @@ def get_server_address(self):
assert self._server_port is not None, "http server is not launched, port is None"
return self._server_address, self._server_port
+ @property
+ def lora_as_adapter(self) -> bool:
+ return (
+ self.model_config.lora_rank > 0 or self.model_config.lora.get("rank", 0) > 0
+ ) and not self.model_config.lora.get("merge", False)
+
async def collective_rpc(
self,
method: str | Callable,
@@ -231,8 +235,25 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
set_expandable_segments(True)
quantization = self.config.quantization
+ hf_overrides = {}
+
+ # Handle QAT (Quantization-Aware Training) configuration
+ qat_config_dict = getattr(self.config, "qat", {}) or {}
+ if qat_config_dict.get("enable", False):
+ # QAT uses compressed-tensors quantization, apply patches for dynamic weight loading
+ from verl.utils.qat import QATConfig, apply_qat_patches, load_quantization_config
- if quantization is not None:
+ apply_qat_patches()
+
+ # Load quantization config from JSON file
+ qat_config = QATConfig(**qat_config_dict)
+ quantization_config_dict = load_quantization_config(qat_config)
+ hf_overrides["quantization_config"] = quantization_config_dict
+ quantization = "compressed-tensors"
+
+ logger.info("QAT quantization config injected to vLLM async server")
+ elif quantization is not None:
+ # Handle other quantization methods (fp8, torchao)
_SUPPORTED_QUANTIZATION = ["fp8", "torchao"]
if quantization not in _SUPPORTED_QUANTIZATION:
raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}")
@@ -250,19 +271,16 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"weight_block_size": [128, 128],
"ignored_layers": all_mlp_gate_layers,
}
- fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
+ hf_overrides["quantization_config"] = dict(FP8_BLOCK_QUANT_KWARGS)
# Apply vllm fp8 patches
# Will remove the patch after vllm support on-the-fly quant for rollout natively.
apply_vllm_fp8_patches()
# for subprocesses patching
os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1"
- hf_overrides = {}
if quantization is not None and self.config.quantization_config_file is not None:
hf_overrides["quantization_config_file"] = self.config.quantization_config_file
- if quantization == "fp8":
- hf_overrides["quantization_config"] = fp8_block_quant_kwargs
compilation_config = engine_kwargs.pop("compilation_config", None) or {}
if isinstance(compilation_config, str):
compilation_config = json.loads(compilation_config)
@@ -410,6 +428,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
# 3. launch server
if self.node_rank == 0:
self._master_sock.close()
+ self._dp_rpc_sock.close()
+ self._dp_master_sock.close()
await self.run_server(server_args)
else:
# TODO: avoid connect before master_sock close
@@ -456,7 +476,7 @@ async def run_server(self, args: argparse.Namespace):
logger.info(f"Initializing a V1 LLM engine with config: {vllm_config}")
self.engine = engine_client
- self._server_port, self._server_task = await run_unvicorn(app, args, self._server_address)
+ self._server_port, self._server_task = await run_uvicorn(app, args, self._server_address)
async def run_headless(self, args: argparse.Namespace):
"""Run headless server in a separate thread."""
@@ -529,9 +549,7 @@ async def generate(
# Add lora request
lora_request = None
- if (
- self.model_config.lora_rank > 0 or self.model_config.lora.get("rank", 0) > 0
- ) and not self.model_config.lora.get("merge", False):
+ if self.lora_as_adapter:
# Make sure we also check that the lora is already loaded in the engine
lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras()
if lora_loaded:
@@ -604,7 +622,12 @@ async def sleep(self):
if self.rollout_mode == RolloutMode.HYBRID:
# Don't use engine.sleep(level=2) here
- await self.engine.collective_rpc("sleep", kwargs={"level": 2})
+ # LoRA only updates the adapter weights, so sleep level 1 is sufficient
+ if self.lora_as_adapter:
+ sleep_level = 1
+ else:
+ sleep_level = 2
+ await self.engine.collective_rpc("sleep", kwargs={"level": sleep_level})
# clear encoder cache: https://github.com/vllm-project/vllm/pull/33452
# await self.engine.reset_encoder_cache()
@@ -751,9 +774,6 @@ async def abort_request(self, request_id: str, reset_prefix_cache: bool = True)
return {"aborted": False, "request_id": request_id, "error": str(e)}
-_rollout_worker_actor_cls = ray.remote(ServerAdapter)
-
-
class vLLMReplica(RolloutReplica):
def __init__(
self,
@@ -766,16 +786,6 @@ def __init__(
super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)
self.server_class = ray.remote(vLLMHttpServer)
- def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
- """Get rollout worker actor class for colocated and standalone mode."""
- worker_dict_cls = RayClassWithInitArgs(
- cls=_rollout_worker_actor_cls,
- config=self.config,
- model_config=self.model_config,
- device_mesh=None,
- )
- return worker_dict_cls
-
async def launch_servers(self):
"""Launch http server in each node."""
assert len(self.workers) == self.world_size, (
@@ -822,8 +832,14 @@ async def launch_servers(self):
node_id=node_id,
soft=False,
),
- runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
+ runtime_env={
+ "env_vars": {
+ "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
+ "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
+ }
+ },
name=name,
+ max_concurrency=self.max_concurrency,
).remote(
config=self.config,
model_config=self.model_config,
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 75efb81d892..53a433cc51e 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -72,6 +72,7 @@ def __init__(
config: RolloutConfig,
model_config: HFModelConfig,
device_mesh: DeviceMesh,
+ replica_rank: int = -1,
):
super().__init__(config, model_config, device_mesh)
self.server_handle: ray.actor.ActorHandle = None
@@ -83,7 +84,10 @@ def __init__(
* self.config.data_parallel_size
* self.config.pipeline_model_parallel_size
)
- self.replica_rank = rank // rollout_world_size
+ if replica_rank == -1:
+ self.replica_rank = rank // rollout_world_size
+ else:
+ self.replica_rank = replica_rank
self.rollout_rank = rank % rollout_world_size
self.node_rank = self.rollout_rank // local_world_size
@@ -169,7 +173,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
buffer, shm = None, None
if not self.use_shm:
- buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:0")
+ buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}")
handle = reduce_tensor(buffer)
s.send_pyobj(handle)
else:
@@ -228,6 +232,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
# clean up
s.close()
del buffer
+ gc.collect()
if shm is not None:
shm.close()
shm.unlink()
diff --git a/verl/workers/utils/padding.py b/verl/workers/utils/padding.py
index d4bbf88e8d8..16242e7731f 100644
--- a/verl/workers/utils/padding.py
+++ b/verl/workers/utils/padding.py
@@ -105,7 +105,7 @@ def no_padding_2_padding(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor
prompt_lens = prompt_ids.offsets().diff()
response_lens = response_ids.offsets().diff()
if max_response_len < 0:
- max_response_len = response_ids.offsets().diff().max().item()
+ max_response_len = response_lens.max().item()
else:
assert not attention_mask.is_nested
prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)