diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 533f55d332b..2ab905806c4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -20,6 +20,7 @@ /verl/workers/actor/megatron_actor.py @ISEEKYAN @vermouth1992 /verl/workers/critic/megatron_critic.py @ISEEKYAN @vermouth1992 /verl/workers/megatron_workers.py @ISEEKYAN @vermouth1992 +/verl/experimental @wuxibin89 @ArronHZG /tests/single_controller @zw0610 @wuxibin89 /tests/trainer @eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2db9aa020fa..91c0d21f2a3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward` + - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml index c7aab7be9cd..578a67d3e2b 100644 --- a/.github/workflows/e2e_ascend.yml +++ b/.github/workflows/e2e_ascend.yml @@ -126,6 +126,10 @@ jobs: ray stop --force export PYTHONPATH=$PYTHONPATH:/Megatron-LM USE_DIST_CKPT=True USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json DUMMY_MODEL_PATH=$HOME/dist_ckpt/qwen3_30b_grpo_mindspeed bash tests/special_npu/run_qwen3_30b_grpo_mindspeed.sh + - name: Running the E2E test with fully_async_policy algorithm (FSDP2) + run: | + ray stop --force + bash tests/special_npu/run_fully_async_policy.sh vlm_rl_job: if: github.repository_owner == 'verl-project' diff --git a/.github/workflows/e2e_one_step_off_policy_ascend.yml b/.github/workflows/e2e_one_step_off_policy_ascend.yml index 3738ede9331..f1b0054a33a 100644 --- a/.github/workflows/e2e_one_step_off_policy_ascend.yml +++ b/.github/workflows/e2e_one_step_off_policy_ascend.yml @@ -68,7 +68,7 @@ on: # Entrypoints - ".github/workflows/e2e_one_step_off_policy_ascend.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_one_step_off_policy.sh" + - "tests/special_npu/run_one_step_off_policy.sh" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -122,7 +122,7 @@ jobs: - name: Running the E2E test with one_step_off_policy algorithm (FSDP2) run: | ray stop --force - bash tests/special_e2e/run_one_step_off_policy.sh + bash tests/special_npu/run_one_step_off_policy.sh # Test Megatron strategy e2e_one_step_off_policy_megatron_ascend: @@ -167,4 +167,4 @@ jobs: run: | ray stop --force export PYTHONPATH=$PYTHONPATH:/Megatron-LM - bash tests/special_e2e/run_one_step_off_policy.sh + bash tests/special_npu/run_one_step_off_policy.sh diff --git a/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml b/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml index 04be1af3b15..f2cdacd0f31 100644 --- a/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml +++ b/.github/workflows/e2e_ppo_trainer_veomni_vllm.yml @@ -134,7 +134,7 @@ jobs: - name: Running GEO3K E2E training tests on 8 L20 GPUs with veomni engine (FSDP_SIZE=8, USP=1) run: | ray stop --force - MODEL_ID=Qwen/Qwen3-VL-2B-Instruct TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/gsm8k/test.parquet VAL_BEFORE_TRAIN=True NUM_GPUS=8 FSDP_SIZE=8 SP_SIZE=1 EP_SIZE=1 VERL_EXP_NAME="qwen3-2b-vl-function-reward-minimal-fsdp-size8" bash tests/special_e2e/run_ppo_trainer_veomni.sh + MODEL_ID=Qwen/Qwen3-VL-2B-Instruct TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/gsm8k/test.parquet VAL_BEFORE_TRAIN=True NUM_GPUS=8 FSDP_SIZE=4 SP_SIZE=2 EP_SIZE=1 VERL_EXP_NAME="qwen3-2b-vl-function-reward-minimal-fsdp-size8" bash tests/special_e2e/run_ppo_trainer_veomni.sh cleanup: runs-on: ubuntu-latest diff --git a/.github/workflows/e2e_sft_llm.yml b/.github/workflows/e2e_sft_llm.yml index 448e433bdea..66515a422d2 100644 --- a/.github/workflows/e2e_sft_llm.yml +++ b/.github/workflows/e2e_sft_llm.yml @@ -110,7 +110,7 @@ jobs: - name: Prepare gsm8k dataset run: | ray stop --force - python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k + python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm run: | ray stop --force @@ -123,10 +123,6 @@ jobs: run: | ray stop --force SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - - name: Check loss difference between sequence parallel vs. default implementation - run: | - ray stop --force - ENTRYPOINT="tests/special_e2e/sft/test_sp_loss_match.py" SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with sequence parallism and liger run: | ray stop --force @@ -140,10 +136,6 @@ jobs: ray stop --force LORA_RANK=32 RESUME_MODE=auto TOTAL_TRAIN_STEP=2 bash tests/special_e2e/sft/run_sft.sh # TODO: multiturn - - name: Prepare gsm8k dataset - run: | - ray stop --force - python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - name: Running GSM8K E2E training tests with multiturn and various configs and compare results run: | bash tests/special_e2e/sft/test_sft_engine_all.sh diff --git a/.github/workflows/e2e_sft_llm_ascend.yml b/.github/workflows/e2e_sft_llm_ascend.yml index 825fb265647..4ccd074cefa 100644 --- a/.github/workflows/e2e_sft_llm_ascend.yml +++ b/.github/workflows/e2e_sft_llm_ascend.yml @@ -109,7 +109,7 @@ jobs: ln -s /root/.cache/models ~/models - name: Prepare gsm8k dataset run: | - python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k + python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - name: Running GSM8K E2E training tests on 8 NPUs with rmpad using function rm run: | ray stop --force @@ -122,10 +122,6 @@ jobs: run: | ray stop --force SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - - name: Check loss difference between sequence parallel vs. default implementation - run: | - ray stop --force - ENTRYPOINT="tests/special_e2e/sft/test_sp_loss_match.py" SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests with LoRA run: | ray stop --force @@ -134,11 +130,6 @@ jobs: run: | ray stop --force LORA_RANK=32 RESUME_MODE=auto TOTAL_TRAIN_STEP=2 bash tests/special_e2e/sft/run_sft.sh - # TODO: multiturn - - name: Prepare gsm8k dataset - run: | - ray stop --force - python3 examples/data_preprocess/gsm8k_multiturn_sft.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - name: Running GSM8K E2E training tests with multiturn and various configs and compare results run: | export PYTHONPATH=$PYTHONPATH:/Megatron-LM diff --git a/.github/workflows/e2e_transferqueue.yml b/.github/workflows/e2e_transferqueue.yml deleted file mode 100644 index 309f715be1b..00000000000 --- a/.github/workflows/e2e_transferqueue.yml +++ /dev/null @@ -1,172 +0,0 @@ -# # Tests layout - -# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: -# - `tests/trainer` for testing functionality related to `verl/trainer` -# - `tests/models` for testing functionality related to `verl/models` -# - ... - -# There are a few folders with `special_` prefix, created for special purposes: -# - `special_distributed`: unit tests that must run with multiple GPUs -# - `special_e2e`: end-to-end tests with training/generation scripts -# - `special_npu`: tests for NPUs -# - `special_sanity`: a suite of quick sanity tests -# - `special_standalone`: a set of test that are designed to run in dedicated environments - -# Accelerators for tests -# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. -# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. - -# # Workflow layout - -# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: -# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` -# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` -# 3. End-to-end tests: `e2e_*.yml` -# 4. Unit tests -# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` -# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. -# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when -# - new workflow yaml is added to `.github/workflows` -# - new tests are added to workflow mentioned in 2. - -name: e2e_transferqueue - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - # For push, for now only anti-patterns are specified so it is more conservative - # and achieves higher coverage. - push: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/*trainer*" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - - "verl/experimental/transfer_queue/**" - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Home - - "verl/experimental/transfer_queue" - # Entrypoints - - ".github/workflows/e2e_transferqueue.yml" - - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_transferqueue.sh" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -env: - IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:vllm012.dev3" - DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" - -jobs: - setup: - if: github.repository_owner == 'verl-project' - runs-on: ubuntu-latest - outputs: - runner-label: ${{ steps.create-runner.outputs.runner-label }} - mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} - steps: - - uses: actions/checkout@v4 - - id: create-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "create" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-image: "${{ env.IMAGE }}" - - # Test FSDP strategy - e2e_transferqueue_fsdp: - needs: setup - runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] - timeout-minutes: 10 # Increase timeout for async training - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "fsdp" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -r requirements-test.txt - pip3 install --no-deps -e . - pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - - name: Running the E2E test with TransferQueue (FSDP), enable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True - bash tests/special_e2e/run_transferqueue.sh - - # Test Megatron strategy - e2e_transferqueue_megatron: - needs: setup - runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] - timeout-minutes: 10 # Increase timeout for async training - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "megatron" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -r requirements-test.txt - pip3 install --no-deps -e . - pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k - - name: Running the E2E test with TransferQueue (Megatron), disable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=False - bash tests/special_e2e/run_transferqueue.sh - - cleanup: - runs-on: ubuntu-latest - needs: [setup, e2e_transferqueue_fsdp, e2e_transferqueue_megatron] - if: always() - steps: - - id: destroy-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "destroy" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" diff --git a/.github/workflows/e2e_transferqueue_ascend.yml b/.github/workflows/e2e_transferqueue_ascend.yml deleted file mode 100644 index ca1dc2de87e..00000000000 --- a/.github/workflows/e2e_transferqueue_ascend.yml +++ /dev/null @@ -1,174 +0,0 @@ -# # Tests layout - -# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: -# - `tests/trainer` for testing functionality related to `verl/trainer` -# - `tests/models` for testing functionality related to `verl/models` -# - ... - -# There are a few folders with `special_` prefix, created for special purposes: -# - `special_distributed`: unit tests that must run with multiple GPUs -# - `special_e2e`: end-to-end tests with training/generation scripts -# - `special_npu`: tests for NPUs -# - `special_sanity`: a suite of quick sanity tests -# - `special_standalone`: a set of test that are designed to run in dedicated environments - -# Accelerators for tests -# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. -# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. - -# # Workflow layout - -# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: -# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` -# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` -# 3. End-to-end tests: `e2e_*.yml` -# 4. Unit tests -# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` -# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. -# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when -# - new workflow yaml is added to `.github/workflows` -# - new tests are added to workflow mentioned in 2. - -name: e2e_transferqueue_ascend - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - # For push, for now only anti-patterns are specified so it is more conservative - # and achieves higher coverage. - push: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/*trainer*" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - - "verl/experimental/transfer_queue/**" - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - "!**/*.md" - - "!**/*.sh" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Home - - "verl/experimental/transfer_queue" - # Entrypoints - - ".github/workflows/e2e_transferqueue_ascend.yml" - - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_transferqueue.sh" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - # Test FSDP strategy - e2e_transferqueue_fsdp_ascend: - if: github.repository_owner == 'verl-project' - runs-on: linux-aarch64-a2-8 - timeout-minutes: 60 # Increase this timeout value as needed - container: - image: swr.ap-southeast-1.myhuaweicloud.com/base_image/ascend-ci/verl/verl:verl-8.3.rc1-910b-ubuntu22.04-py3.11-latest - options: >- - --shm-size 16g - env: - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "fsdp" - steps: - - name: Check npu and CANN info - run: | - cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info - npu-smi info - - name: Check initial pip list from image - run: | - pip list - - name: Checkout verl-project/verl repo - uses: actions/checkout@v4 - with: - fetch-depth: 0 - clean: true - - name: Install the current repository - run: | - pip install -r requirements-npu.txt - pip install --no-deps -e . - pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Check final pip list - run: | - pip list - - name: Prepare weights - run: | - ln -s /root/.cache/models ~/models - - name: Prepare GSM8K dataset - run: | - python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - - name: Running the E2E test with TransferQueue (FSDP), enable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True - bash tests/special_e2e/run_transferqueue.sh - - # Test Megatron strategy - e2e_transferqueue_megatron_ascend: - if: github.repository_owner == 'verl-project' - runs-on: linux-aarch64-a2-8 - timeout-minutes: 60 # Increase this timeout value as needed - container: - image: swr.ap-southeast-1.myhuaweicloud.com/base_image/ascend-ci/verl/verl:verl-8.3.rc1-910b-ubuntu22.04-py3.11-latest - options: >- - --shm-size 16g - env: - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - ACTOR_STRATEGY: "megatron" - steps: - - name: Check npu and CANN info - run: | - cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info - npu-smi info - - name: Check initial pip list from image - run: | - pip list - - name: Checkout verl-project/verl repo - uses: actions/checkout@v4 - with: - fetch-depth: 0 - clean: true - - name: Install the current repository - run: | - pip install -r requirements-npu.txt - pip install --no-deps -e . - pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.5 - - name: Check final pip list - run: | - pip list - - name: Prepare weights - run: | - ln -s /root/.cache/models ~/models - - name: Prepare GSM8K dataset - run: | - python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - - name: Running the E2E test with TransferQueue (Megatron), disable zero copy serialization - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=False - export PYTHONPATH=$PYTHONPATH:/Megatron-LM - bash tests/special_e2e/run_transferqueue.sh diff --git a/.github/workflows/gpu_unit_tests.yml b/.github/workflows/gpu_unit_tests.yml index d16075e19ea..eaafeadb15e 100644 --- a/.github/workflows/gpu_unit_tests.yml +++ b/.github/workflows/gpu_unit_tests.yml @@ -108,7 +108,7 @@ jobs: pip3 install hf_transfer pip3 install -r requirements-test.txt pip3 install --no-deps -e . - pip3 install cupy-cuda12x pytest-asyncio + pip3 install cupy-cuda12x==13.6.0 pytest-asyncio pip3 install --ignore-installed blinker pip3 install --ignore-installed mlflow "numpy<2.0" - name: Run all GPU unit tests diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml index f9158731bac..2c928a9527e 100644 --- a/.github/workflows/sgl.yml +++ b/.github/workflows/sgl.yml @@ -113,7 +113,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install cupy-cuda12x pytest-asyncio + pip3 install cupy-cuda12x==13.6.0 pytest-asyncio pip3 install hf_transfer fastmcp pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . @@ -144,7 +144,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install cupy-cuda12x pytest-asyncio + pip3 install cupy-cuda12x==13.6.0 pytest-asyncio pip3 install hf_transfer fastmcp pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 84cdb45c3f6..15c51678030 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -144,7 +144,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install cupy-cuda12x pytest-asyncio + pip3 install pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . pip3 install --upgrade "transformers<5.0" diff --git a/.gitignore b/.gitignore index 62d4dcfc815..e6c0f5a08e3 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ **/playground **/wandb +/pyrightconfig.json + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index cea8b9945e9..a7d04fc5ba5 100644 --- a/README.md +++ b/README.md @@ -237,6 +237,7 @@ Welcome to register your awesome project build with `verl` for other developers' - [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling deep research via reinforcement learning in real-world environments ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher) - [VAGEN](https://github.com/RAGEN-AI/VAGEN): Training VLM agents with multi-turn reinforcement learning ![GitHub Repo stars](https://img.shields.io/github/stars/RAGEN-AI/VAGEN) - [RM-R1](https://arxiv.org/abs/2505.02387): RL training of reasoning reward models ![GitHub Repo stars](https://img.shields.io/github/stars/RM-R1-UIUC/RM-R1) +- [Dr. MAS](https://arxiv.org/pdf/2602.08847): Stable **end-to-end RL** post-training for **multi-agent LLM systems** ![GitHub Repo stars](https://img.shields.io/github/stars/langfengQ/DrMAS) - [LUFFY](https://arxiv.org/pdf/2504.14945): Learning to Reason under Off-Policy Guidance![GitHub Repo stars](https://img.shields.io/github/stars/ElliottYan/LUFFY) - [DeepMath](https://github.com/zwhe99/DeepMath): DeepMath-103K data and series models for math reasoning![GitHub Repo stars](https://img.shields.io/github/stars/zwhe99/DeepMath) - [PACS](https://github.com/ritzz-ai/PACS): Implicit Actor Critic Coupling via a Supervised Learning Framework for RLVR ![GitHub Repo stars](https://img.shields.io/github/stars/ritzz-ai/PACS) @@ -283,6 +284,7 @@ Welcome to register your awesome project build with `verl` for other developers' - [NoisyRollout](https://github.com/NUS-TRAIL/NoisyRollout): Reinforcing Visual Reasoning with Data Augmentation ![GitHub Repo stars](https://img.shields.io/github/stars/NUS-TRAIL/NoisyRollout) - [SPEAR](https://github.com/TencentYoutuResearch/SPEAR): **Self-imitation** with **Progressive Exploration** for Agentic Reinforcement Learning (ICLR 2026) ![GitHub Repo stars](https://img.shields.io/github/stars/TencentYoutuResearch/SPEAR) - [RuleReasoner](https://github.com/bigai-nlco/RuleReasoner): **RuleReasoner:** Reinforced Rule-based Reasoning via **Domain-aware Dynamic Sampling** (ICLR 2026) ![GitHub Repo stars](https://img.shields.io/github/stars/bigai-nlco/RuleReasoner) +- [MetaphorStar](https://metaphorstar.github.io/): **Image Metaphor** Understanding and Reasoning with End-to-End **Visual Reinforcement Learning** ![GitHub Repo stars](https://img.shields.io/github/stars/MING-ZCH/MetaphorStar) ## Contribution Guide diff --git a/docker/Dockerfile.stable.vllm b/docker/Dockerfile.stable.vllm index 158eccac580..701713d1121 100644 --- a/docker/Dockerfile.stable.vllm +++ b/docker/Dockerfile.stable.vllm @@ -32,6 +32,9 @@ RUN pip install torch==2.9.1 torchvision torchaudio --index-url https://download RUN sed -i '/nvidia-cudnn-cu12/d' /usr/local/lib/python3.12/dist-packages/torch-2.9.1+cu129.dist-info/METADATA RUN pip install --no-deps --force-reinstall nvidia-cudnn-cu12==9.16.0.29 +# NOTE: This installs the `vllm` source code in `/vllm`. +# This might break the (based)pyright type checking. To fix it, add `/vllm` to `extraPaths` in `pyrightconfig.json`. +# c.f. https://docs.basedpyright.com/latest/configuration/config-files/ RUN git clone --depth 1 -b v0.12.0 https://github.com/vllm-project/vllm.git && \ cd vllm && \ find requirements -name "*.txt" -print0 | xargs -0 sed -i '/torch/d' && \ diff --git a/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 b/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 index 2d9993cd921..baa55c0e348 100644 --- a/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 +++ b/docker/ascend/Dockerfile.ascend_8.3.rc1_a2 @@ -18,7 +18,7 @@ RUN ARCH=$(uname -m) && \ fi && \ # Clone libs git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm.git && \ - git clone --depth 1 --branch v0.11.0rc1 https://github.com/vllm-project/vllm-ascend.git && \ + git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm-ascend.git && \ git clone https://gitcode.com/Ascend/MindSpeed.git && \ cd MindSpeed && git checkout f2b0977e && cd .. && \ git clone --depth 1 --branch core_v0.12.1 https://github.com/NVIDIA/Megatron-LM.git diff --git a/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 b/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 index bead95cea6c..f5d74a7e70a 100644 --- a/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 +++ b/docker/ascend/Dockerfile.ascend_8.3.rc1_a3 @@ -18,7 +18,7 @@ RUN ARCH=$(uname -m) && \ fi && \ # Clone libs git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm.git && \ - git clone --depth 1 --branch v0.11.0rc1 https://github.com/vllm-project/vllm-ascend.git && \ + git clone --depth 1 --branch v0.11.0 https://github.com/vllm-project/vllm-ascend.git && \ git clone https://gitcode.com/Ascend/MindSpeed.git && \ cd MindSpeed && git checkout f2b0977e && cd .. && \ git clone --depth 1 --branch core_v0.12.1 https://github.com/NVIDIA/Megatron-LM.git diff --git a/docs/advance/fully_async.md b/docs/advance/fully_async.md index a2b7ccb3aea..e7c8962ad9c 100644 --- a/docs/advance/fully_async.md +++ b/docs/advance/fully_async.md @@ -106,9 +106,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev | `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization | | `async_training.staleness_threshold` | Freshness control | | `async_training.partial_rollout` | Whether to perform partial_rollout | -| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for accelerating, default `True` | -| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False` | -| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096` | | `async_training.use_trainer_do_validate` | Whether use trainer node to do validate process, default `False` | **Further Explanation:** @@ -182,27 +179,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev mode d (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`. -* `async_training.checkpoint_engine.enable` - - Enabling the checkpoint engine generally reduces synchronization time overhead by more than 60% compared to - the original per-tensor parameter synchronization method. However, assembling buckets incurs additional - temporary GPU memory overhead. - -* `async_training.checkpoint_engine.overlap_broadcast_and_consume` - - Enabling pipeline between the broadcast and load_weights parameters will allocate additional GPU memory. - Since the main time consumption for parameter synchronization is not in the broadcast and load_weights phases, - but in the parameter generation phase (by megatron or FSDP), this option is off by default. - -* `async_training.checkpoint_engine.device_buffer_size_M` - - It controls the size of the memory buffer used for synchronization when the checkpoint-engine is enabled. - The actual `bucket_size` = `max(device_buffer_size_M, maximum parameter tensor size)`. - * When enable `overlap_broadcast_and_consume`, the additional device memory overhead of - trainer rank is `3 * bucket_size`and rollout rank is `2 * bucket_size`。 - * When disable `overlap_broadcast_and_consume`, the additional device memory overhead of - trainer rank is `2 * bucket_size`and rollout rank is `1 * bucket_size`。 - * `async_training.use_trainer_do_validate` It controls whether to use the trainer's `do_validate` method for validation. diff --git a/docs/advance/mtp.md b/docs/advance/mtp.md index b4c5a25c631..5f1698d3ddc 100644 --- a/docs/advance/mtp.md +++ b/docs/advance/mtp.md @@ -2,19 +2,21 @@ **Author**: `https://github.com/meituan-search` -Last updated: 01/30/2026 +Last updated: 02/15/2026 # 1. Scope of Support Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek series models based on the MTP architecture. The support rules for training and inference engines are as follows: -- **Training Engine**: Only supports the `mbridge + megatron` combination; other training engines are not compatible at this time; +- **Training Engine**: Only supports the `mbridge/Megatron-Bridge + megatron` combination; other training engines are not compatible at this time; - **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list; - **Dependency Versions**: - - mbridge: Use the specified branch: [https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp](https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp) (will be merged into the main branch in the future); + - mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future); + + - Megatron-Bridge: Apply the patches and review suggestions from PR if you want to try out mimo-7B-RL: [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) (will be merged into the main branch in the future); - megatron: Use the latest dev version (commit: [23e092f41ec8bc659020e401ddac9576c1cfed7e](https://github.com/NVIDIA/Megatron-LM/tree/23e092f41ec8bc659020e401ddac9576c1cfed7e)), which supports MTP + CP training methods. diff --git a/docs/algo/dppo.md b/docs/algo/dppo.md new file mode 100644 index 00000000000..aab22bf7fa6 --- /dev/null +++ b/docs/algo/dppo.md @@ -0,0 +1,96 @@ +# Divergence Proximal Policy Optimization (DPPO) + +Last updated: 02/25/2026. + + +
+ +## Rethinking the Trust Region in LLM Reinforcement Learning + +[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white )](https://arxiv.org/pdf/2602.04879) +[![Github](https://img.shields.io/badge/Stable_RL-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/sail-sg/Stable-RL) +[![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/QPHutu/status/2019435642539897303) + +
+ + +## ✨Getting started + +1. Prepare the datasets by running [prepare_dapo_data.sh](https://github.com/verl-project/verl-recipe/blob/3490a22a0a3adeb7e4787fe70b1060b642efbae4/dapo/prepare_dapo_data.sh): + +```bash +bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default +``` + +2. Prepare the model: + +```bash +hf download Qwen/Qwen3-30B-A3B-Base --local-dir ${HOME}/verl/models/Qwen3-30B-A3B-Base +``` + +3. Run the script: + +```bash +# run DPPO-Binary-KL +LOSS_MODE=dppo_kl bash examples/dppo_trainer/run_qwen30b_dppo.sh + +# run DPPO-Binary-TV +LOSS_MODE=dppo_tv bash examples/dppo_trainer/run_qwen30b_dppo.sh + +# run GRPO baseline +LOSS_MODE=vanilla CLIP_LOW=0.2 CLIP_HIGH=0.2 bash examples/dppo_trainer/run_qwen30b_dppo.sh +# or GRPO with clip higher +LOSS_MODE=vanilla CLIP_LOW=0.2 CLIP_HIGH=0.28 bash examples/dppo_trainer/run_qwen30b_dppo.sh +``` + +## 📖Introduction + +
+ issue +
+ +Comparison of **PPO** and the proposed **DPPO** (the Binary-TV variant). **(Left)** The surrogate objective and corresponding masks for PPO and DPPO. PPO (and variants like GRPO) employs a heuristic mask based on the probability ratio. In contrast, DPPO utilizes a more principled mask based on a direct approximation of policy divergence (e.g., Total Variation), ensuring updates stay within a theoretically grounded trust region. **(Right)** Experimental results on the AIME24 using Qwen3-30B-A3B-Base. DPPO significantly outperforms GRPO baselines, achieving superior training stability and final performance even without rollout routing replay (R3). + +
+ issue +
+ +DPPO variants achieve stable training while controlling the training-inference mismatch at a low level. In contrast, methods without a trust region (PG-IS, CISPO) or with a misspecified one (MiniRL) suffer from growing mismatch and eventual collapse. + +
+ issue +
+ +The plots show numerical differences between a training and an inference engine for Qwen3-30B-A3B-Base with identical parameters. **(Left)** The probability ratio (used in PPO) is highly volatile for low-probability tokens. **(Right)** In contrast, the TV divergence is more stable. This highlights a key flaw of PPO's clipping mechanism: it **over-penalizes low-probability tokens**, which can slow down learning; and **under-penalizes high-probability tokens**, which can permit large, destabilizing updates. + + +
+ issue +
+ +The most frequently clipped tokens (by GRPO) are important to the reasoning task! +They are dominated by: +- numbers, like 1, 4 +- mathematical symbols, like +, -, = +- reasoning and structural Words: Wait, Thus, Next + +## Top-K divergence approximation + +We only implement the DPPO-Binary-TV/DPPO-Binary-KL here due to their simplicity. + +For the TopK divergence approximation, please refer to the [the original repo](https://github.com/sail-sg/Stable-RL) for a complete implementation. + +## Citation +If you find our works useful for your research, please consider citing: + +```bibtex +@article{qi2026dppo, + title={Rethinking the Trust Region in LLM Reinforcement Learning}, + author={Qi, Penghui and Zhou, Xiangxin and Liu, Zichen and Pang, Tianyu and Du, Chao and Lin, Min and Lee, Wee Sun}, + journal={arXiv preprint arXiv:2602.04879}, + year={2026} +} +``` + +## 🌻Acknowledgement +We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) and [sglang](https://github.com/sgl-project/sglang) for inference. Our models are trained primarily on [Qwen3 family](https://huggingface.co/collections/Qwen/qwen3). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! diff --git a/docs/algo/otb.md b/docs/algo/otb.md index 288eb71bd69..ad39375f090 100644 --- a/docs/algo/otb.md +++ b/docs/algo/otb.md @@ -1,18 +1,19 @@ # Optimal Token Baseline (OTB) -Last updated: 12/25/2025. +Last updated: 02/23/2026. -Optimal Token Baseline (OTB) is dynamic token-level baseline for variance reduction. It weights updates based on "Realized Energy"—essentially, how much uncertainty has accumulated up to that specific token. It downweights the noisy parts and trusts the clear signals. Read [Optimal Token Baseline blog](https://richardli.xyz/optimal-token-baseline) for more details. +📝 [ArXiv](https://www.arxiv.org/abs/2602.07078) | 📒 [Blog](https://richardli.xyz/optimal-token-baseline) | 🤗 [Datasets](https://huggingface.co/datasets/Jiawei415/DPAO_filter) -## The method: OTB +Optimal Token Baseline (OTB) is a dynamic token-level baseline for gradient variance reduction in policy-gradient reinforcement learning. It weights updates with the "Realized Energy" statistic that tracks how much uncertainty has accumulated up to each token, so noisy regions get downweighted while confident regions carry more weight. -- OTB builds a _dynamic_ baseline that adapts to each token by tracking the “Realized Energy”—the uncertainty that has accumulated up to that token. It downweights the noisy parts and trusts the clear signals. -- Unlike standard group means (which average over the padding `EOS` token ineffectively), OTB handles this naturally by computing baselines only over valid tokens. +## Key properties + +- _Token-level baselines:_ OTB adapts per token by tracking realized energy, avoiding the padding artifacts that appear when group means dilute the signal with `EOS` tokens. +- _Forward-only overhead:_ The realized-energy statistic is computed via the **Logit-Gradient Proxy**, so OTB requires no extra backward passes or gradient-norm kernels. ## Logit-Gradient Proxy -- Computing true uncertainty requires expensive backward passes (calculating gradient norms per token). Instead, OTB introduces the **Logit-Gradient Proxy**: the realized energy can be estimated entirely from forward probabilities. -- This means zero extra backward calls and effectively no additional runtime overhead. +Computing true uncertainty per token would normally mandate per-token backward passes. OTB sidesteps this by estimating realized energy entirely from forward probabilities, so it introduces negligible runtime overhead in practice. ## Mechanics at a glance @@ -38,7 +39,7 @@ The final advantage is `(G_t - B*_t) · mask_t`, so padding tokens stay at zero. - `actor_rollout_ref.actor.calculate_sum_pi_squared: true` (mandatory). - `actor_rollout_ref.model.use_fused_kernels: false` (required until fused kernels emit logits). -- `algorithm.adv_estimator: optimal_token_baseline`. +- `algorithm.adv_estimator: optimal_token_baseline` for single-turn RL and `tir_optimal_token_baseline` for multi-turn RL. - Group sampling (`actor_rollout_ref.rollout.n > 1`) to unlock OTB’s variance reduction; with `n=1` the baseline collapses to returns. Example OmegaConf overlay: @@ -57,7 +58,7 @@ actor_rollout_ref: ## Example script -- `examples/otb_trainer/run_qwen2_5-7b.sh`. +See `examples/otb_trainer/run_qwen2_5-7b.sh` for a reference training loop. ## Gradient Variance Proxy Metrics @@ -96,9 +97,6 @@ where `Ŵ(τ)` is the realized energy built. Given a mini-batch `{τ_i}` of siz Var_proxy = (1/(N-1)) · (P_total - S) ``` -`verl/trainer/ppo/metric_utils.py#L306` implements these diagnostics via `compute_variance_proxy_metrics`, emitting -`variance_proxy/proxy1_signal_strength`, -`variance_proxy/proxy2_total_power`, and -`variance_proxy/proxy3_pure_noise`. +`verl/trainer/ppo/metric_utils.py#L306` implements these diagnostics via `compute_variance_proxy_metrics`, emitting `variance_proxy/proxy1_signal_strength`, `variance_proxy/proxy2_total_power`, and `variance_proxy/proxy3_pure_noise`. Tracking these metrics provides a forward-only, low-overhead view of gradient health for any advantage estimator that supplies `sum_pi_squared`. diff --git a/docs/algo/rollout_corr.md b/docs/algo/rollout_corr.md index 8569b243a9e..066919b67d5 100644 --- a/docs/algo/rollout_corr.md +++ b/docs/algo/rollout_corr.md @@ -29,6 +29,13 @@ This document provides a comprehensive overview of the Rollout Correction implem month = sep, url = {https://richardli.xyz/rl-collapse} } + +@article{li2025trust, + title={Trust Region Masking for Long-Horizon LLM Reinforcement Learning}, + author={Li, Yingru and Liu, Jiacai and Xu, Jiawei and Tong, Yuxuan and Li, Ziniu and Liu, Qian and Wang, Baoxiang}, + journal={arXiv preprint arXiv:2512.23075}, + year={2025} +} ``` ### Blog Series @@ -37,6 +44,7 @@ This document provides a comprehensive overview of the Rollout Correction implem - [Part 1: Why Mismatch Breaks LLM-RL](https://richardli.xyz/rl-collapse-1) (analytical framework using TV distance for bias and χ²-divergence for variance) - [Part 2: The Gradient Estimator Trials](https://richardli.xyz/rl-collapse-2) (token-level vs sequence-level correction bias-variance tradeoff) - [Part 3: When Math Meets Reality—Toxic Tails and Length Traps](https://richardli.xyz/rl-collapse-3) (why rejection over clipping, and geometric-level RS) +- Latest Paper: https://arxiv.org/abs/2512.23075 ## Overview @@ -85,11 +93,11 @@ This critical implementation mistake that leads to RL training collapse was iden **Mathematically correct approaches:** -- **Decoupled mode**: Three policies (π*rollout, π_old, π*θ) with IS correction from π_rollout to π_old -- **Bypass mode**: Two policies (π*rollout = π_old, π*θ) using actual rollout policy as PPO anchor -- **Bypass + Policy Gradient mode**: Two policies (π*rollout, π*θ) with IS/RS correction and no PPO clipping +- **Decoupled mode**: Three policies (π_rollout, π_old, π_θ) with IS correction from π_rollout to π_old +- **Bypass mode**: Two policies (π_rollout = π_old, π_θ) using actual rollout policy as PPO anchor +- **Bypass + Policy Gradient mode**: Two policies (π_rollout, π_θ) with IS/RS correction and no PPO clipping -See [Mathematical Formulations](rollout_corr_math.md#38-common-implementation-mistake) for detailed explanation. +See [Mathematical Formulations](rollout_corr_math.md#37-common-implementation-mistake) for detailed explanation. ### Key Design Principle: Separation of IS Weights and Rejection Sampling @@ -97,7 +105,7 @@ The implementation cleanly separates two orthogonal mechanisms: 1. **IS Weights** (`rollout_is_weights`): Continuous reweighting for gradient correction - - Policy ratio: π*old/π_rollout (decoupled) or π*θ/π_rollout (bypass) + - Policy ratio: π_old/π_rollout (decoupled) or π_θ/π_rollout (bypass) - **Safety-bounded**: Clamped to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow - Token level: Bounds per-token ratios - Sequence level: Bounds product of ratios (broadcast to all tokens) @@ -109,7 +117,6 @@ The implementation cleanly separates two orthogonal mechanisms: - Creates binary mask: 1 = keep, 0 = reject - Rejects tokens/sequences with IS ratios outside [lower_threshold, upper_threshold] - Modifies response_mask to exclude rejected samples from training - - Used for loss aggregation (rejected samples don't contribute to gradients) This separation ensures: @@ -142,13 +149,13 @@ config = RolloutCorrectionConfig.decoupled_k3_rs_token_tis() # K3-RS + Token- # === Bypass PPO mode (2 policies: π_rollout = π_old, π_θ) - fast === # PPO ratio handles IS, so no explicit IS weights needed config = RolloutCorrectionConfig.bypass_ppo_clip() # PPO-clip only -config = RolloutCorrectionConfig.bypass_ppo_clip_geo_rs() # PPO-clip + Geo-RS (ratio) +config = RolloutCorrectionConfig.bypass_ppo_clip_geo_rs() # PPO-clip + Geo-RS config = RolloutCorrectionConfig.bypass_ppo_clip_k3_rs() # PPO-clip + K3-RS # === Bypass PG mode (2 policies, no PPO clipping) - fast === # IS weights computed on-the-fly as π_θ / π_rollout config = RolloutCorrectionConfig.bypass_pg_is() # Seq-TIS + PG -config = RolloutCorrectionConfig.bypass_pg_geo_rs() # Geo-RS + PG (ratio) +config = RolloutCorrectionConfig.bypass_pg_geo_rs() # Geo-RS + PG config = RolloutCorrectionConfig.bypass_pg_geo_rs_token_tis() # Geo-RS + Token-TIS + PG # === Other === @@ -187,8 +194,8 @@ actor_rollout_ref: ### **Configuration Files** -- `verl/trainer/config/algorithm.py` - Rollout Correction parameters in `AlgoConfig` -- `verl/workers/config/actor.py` - Rollout Correction parameters in `ActorConfig` +- `verl/trainer/config/algorithm.py` - Rollout Correction parameters in `RolloutCorrectionConfig` +- `verl/workers/config/actor.py` - Rollout Correction parameters in `PolicyLossConfig` - `verl/trainer/config/actor/actor.yaml` - Rollout Correction configuration section - `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with Rollout Correction @@ -218,7 +225,7 @@ Importance sampling weights aggregation level: - `null` = No IS weights computed (metrics-only mode) - `"token"`: Per-token IS weights - **Decoupled mode**: ρ_t = π_old(t)/π_rollout(t) - - **Bypass/Pure IS mode**: ρ*t = π*θ(t)/π_rollout(t) + - **Bypass/Pure IS mode**: ρ_t = π_θ(t)/π_rollout(t) - Independent truncation per token - Typical threshold: 1.5 - 5.0 - `"sequence"`: Per-sequence weight ρ_seq = ∏_t ρ_t @@ -268,8 +275,8 @@ Rejection sampling aggregation modes. Supply a comma-separated string (spaces op Threshold specification for rejection sampling. - Provide **one entry per option**, separated by commas. A single entry is broadcast to every option. -- **Ratio modes (`*k1`)**: Use `"lower_upper"` strings (e.g. `"0.7_1.3"`). Supplying a float implies only the upper bound; the lower bound defaults to its reciprocal. -- **Divergence modes (`*k2`/`*k3`)**: Supply positive upper bounds (float or numeric string). +- **K1 KL modes (`*k1`)**: Use `"lower_upper"` strings (e.g. `"0.7_1.3"`). Supplying a float implies only the upper bound; the lower bound defaults to its reciprocal. +- **K2/K3 KL modes (`*k2`/`*k3`)**: Supply positive upper bounds (float or numeric string). - Set to `null` to disable thresholds entirely (only valid when `rollout_rs` is null). ## Understanding the Framework: Components and Combinations @@ -280,8 +287,8 @@ The rollout correction framework is built from **orthogonal components** that ca 1. **Operating Mode** (Section: [Operation Modes](#operation-modes)) - - **Decoupled**: Three policies (π*rollout, π_old, π*θ) with separate π_old computation - - **Bypass**: Two policies (π*rollout = π_old, π*θ), skips π_old computation + - **Decoupled**: Three policies (π_rollout, π_old, π_θ) with separate π_old computation + - **Bypass**: Two policies (π_rollout = π_old, π_θ), skips π_old computation 2. **Loss Function** (in bypass mode, controlled by `loss_type`) @@ -306,23 +313,23 @@ This section provides detailed guidance on choosing and using the verified prese | Preset Method | Estimator | Mode | IS Level | RS Level | Properties | | ------------------------------------------------------------------------------ | ---------------- | ------------------ | -------- | -------- | --------------------------------------- | -| **Decoupled PPO Mode** (3 policies: π*rollout, π_old, π*θ) | -| `decoupled_token_is()` | Token-TIS | Decoupled | token | - | Per-token IS weights | +| **Decoupled PPO Mode** (3 policies: π_rollout, π_old, π_θ) | +| `decoupled_token_is()` | Token-TIS | Decoupled | token | - | Token-level IS weights | | `decoupled_seq_is()` | Seq-TIS | Decoupled | sequence | - | Sequence-level IS weights | -| `decoupled_seq_is_rs()` | Seq-MIS | Decoupled | sequence | sequence | Sequence IS + sequence RS | -| `decoupled_geo_rs()` | Geo-RS | Decoupled | - | sequence | Geometric RS (ratio mode) | -| `decoupled_geo_rs_token_tis()` | Geo-RS-Token-TIS | Decoupled | token | sequence | Geometric filter + token clipped weight | +| `decoupled_seq_is_rs()` | Seq-MIS | Decoupled | sequence | sequence | Sequence IS + seq_sum_k1 RS | +| `decoupled_geo_rs()` | Geo-RS | Decoupled | - | sequence | Geometric RS (seq_mean_k1) | +| `decoupled_geo_rs_token_tis()` | Geo-RS-Token-TIS | Decoupled | token | sequence | Geometric RS + token IS | | **K3 KL Estimator** (more stable for small KL values) | -| `decoupled_k3_rs()` | K3-RS | Decoupled | - | k3 | K3 rejection, no IS weights | -| `decoupled_k3_rs_token_tis()` | K3-RS-Token-TIS | Decoupled | token | k3 | K3 filter + token clipped weight | +| `decoupled_k3_rs()` | K3-RS | Decoupled | - | sequence | seq_mean_k3 RS | +| `decoupled_k3_rs_token_tis()` | K3-RS-Token-TIS | Decoupled | token | sequence | seq_mean_k3 RS + token IS | | **Bypass Mode (PPO-clip)** (2 policies; ratio handles IS, RS masks outliers) | | `bypass_ppo_clip()` | - | Bypass (PPO-clip) | - | - | PPO-clip only | -| `bypass_ppo_clip_geo_rs()` | Geo-RS | Bypass (PPO-clip) | - | sequence | PPO-clip + Geo-RS (ratio) | -| `bypass_ppo_clip_k3_rs()` | K3-RS | Bypass (PPO-clip) | - | k3 | PPO-clip + K3-RS | +| `bypass_ppo_clip_geo_rs()` | Geo-RS | Bypass (PPO-clip) | - | sequence | PPO-clip + Geo-RS | +| `bypass_ppo_clip_k3_rs()` | K3-RS | Bypass (PPO-clip) | - | sequence | PPO-clip + K3-RS | | **Bypass Mode (REINFORCE)** (2 policies; explicit IS weights, no PPO clipping) | | `bypass_pg_is()` | Seq-TIS | Bypass (REINFORCE) | sequence | - | REINFORCE with explicit IS | -| `bypass_pg_geo_rs()` | Geo-RS | Bypass (REINFORCE) | - | sequence | REINFORCE with Geo-RS (ratio) | -| `bypass_pg_geo_rs_token_tis()` | Geo-RS-Token-TIS | Bypass (REINFORCE) | token | sequence | REINFORCE + Geo filter + token IS | +| `bypass_pg_geo_rs()` | Geo-RS | Bypass (REINFORCE) | - | sequence | REINFORCE with Geo-RS | +| `bypass_pg_geo_rs_token_tis()` | Geo-RS-Token-TIS | Bypass (REINFORCE) | token | sequence | REINFORCE + Geo-RS + token IS | | **Other** | | `disabled()` | - | - | - | - | Metrics only, no correction | @@ -330,15 +337,15 @@ This section provides detailed guidance on choosing and using the verified prese - **Bypass mode** sets π_old = π_rollout and uses `loss_type` to select the loss function: - `"ppo_clip"` (default): PPO clipped objective where ratio = π_θ/π_rollout already handles IS - - `"reinforce"`: REINFORCE with explicit IS weights as π_θ / π_rollout + - `"reinforce"`: REINFORCE with explicit IS weights as π_θ/π_rollout - Both loss types benefit from rejection sampling (RS) which masks out-of-distribution samples. -- Estimators (Token-TIS, Seq-TIS, Seq-MIS, Geo-RS) are compatible with Decoupled and Bypass modes. +- All estimators (Token-TIS, Seq-TIS, Seq-MIS, Geo-RS, ...) are compatible with Decoupled and Bypass modes. #### Other Supported Combinations (Manual Configuration Required) **Other supported combinations without preset methods:** -- Token IS + Token RS: Token-level IS weights + token-level RS mask +- Token IS + Token RS: Token-level IS weights + Token-level RS mask - Pure token RS: Token-level RS only, no IS weights - Pure sequence RS: Sequence-level RS only, no IS weights @@ -346,7 +353,7 @@ See [detailed configuration examples below](#additional-useful-configurations-no **Key properties:** -- Any aggregation level (token/sequence/geometric) works in either decoupled or bypass mode +- Any aggregation level (token/sequence) works in either decoupled or bypass mode - All combinations are fully supported by the implementation - Rejection sampling is independent of IS weighting - Pure RS (`bypass_pg_rs`) uses bypass + geometric RS with `loss_type="reinforce"` (no IS weights) @@ -467,7 +474,7 @@ algorithm: - **Seq-TIS (clipping only)**: Maximizes information efficiency; extracts signal from all samples. Use when data is clean and mismatch is moderate. - **Seq-MIS (rejection)**: Maximizes safety; acts as a hard trust region filter. Use when mismatch is severe or when high-weight samples are likely garbage rather than signal. -**Theory:** See [rollout_corr_math.md §3.4](rollout_corr_math.md#34-rejection-sampling-rs) +**Theory:** See [rollout_corr_math.md §3.5](rollout_corr_math.md#35-rejection-sampling-rs) --- @@ -481,7 +488,7 @@ config = RolloutCorrectionConfig.bypass_ppo_clip() **Components:** -- **Operating Mode**: Bypass (2 policies: π*rollout = π_old, π*θ) +- **Operating Mode**: Bypass (2 policies: π_rollout = π_old, π_θ) - **Loss**: PPO-clip (IS handled by ratio, no explicit IS weights) - **IS Aggregation**: None (PPO ratio handles it) - **RS**: None @@ -489,12 +496,11 @@ config = RolloutCorrectionConfig.bypass_ppo_clip() **Equivalent YAML:** ```yaml -algorithm: - rollout_correction: - rollout_is: null - rollout_rs: null - bypass_mode: true - loss_type: ppo_clip +rollout_correction: + rollout_is: null + rollout_rs: null + bypass_mode: true + loss_type: ppo_clip ``` **Properties:** @@ -508,6 +514,12 @@ algorithm: - Set `actor_rollout_ref.rollout.calculate_log_probs: true` +**Additional requirements for bypass mode:** + +- Set `actor_rollout_ref.actor.use_rollout_log_probs: true` +- Set `actor_rollout_ref.actor.policy_loss.loss_mode: bypass_mode` +- Set rollout correction config via `actor_rollout_ref.actor.policy_loss.rollout_correction` + **Theory:** See [rollout_corr_math.md §3.1.2](rollout_corr_math.md#312-bypass-mode-two-policies) --- @@ -522,7 +534,7 @@ config = RolloutCorrectionConfig.bypass_pg_is(threshold=2.0) **Components:** -- **Operating Mode**: Bypass (2 policies: π*rollout, π*θ) +- **Operating Mode**: Bypass (2 policies: π_rollout, π_θ) - **Loss**: REINFORCE (policy gradient with explicit IS weights, no PPO clipping) - **IS Aggregation**: Sequence-level - **RS**: None @@ -530,13 +542,12 @@ config = RolloutCorrectionConfig.bypass_pg_is(threshold=2.0) **Equivalent YAML:** ```yaml -algorithm: - rollout_correction: - rollout_is: sequence - rollout_is_threshold: 2.0 - rollout_rs: null - bypass_mode: true - loss_type: reinforce # REINFORCE with explicit IS weights +rollout_correction: + rollout_is: sequence + rollout_is_threshold: 2.0 + rollout_rs: null + bypass_mode: true + loss_type: reinforce # REINFORCE with explicit IS weights ``` **Properties:** @@ -633,7 +644,7 @@ Rejection sampling modifies `response_mask` (NOT weights) through `compute_rollo - Computes safety-bounded ratios independently - Creates binary mask: tokens/sequences outside [lower_threshold, upper_threshold] → 0 (rejected) -- Modified mask used for loss aggregation (rejected samples excluded from training) +- Modified mask used for loss aggregation ## Operation Modes @@ -719,12 +730,11 @@ This workflow uses bypass mode for efficiency. 1. **Start with metrics only** to understand the off-policy gap: ```yaml - algorithm: - rollout_correction: - rollout_is: null - rollout_rs: null - bypass_mode: true # Bypass mode (recommended) - loss_type: ppo_clip # Default: PPO clipped objective + rollout_correction: + rollout_is: null + rollout_rs: null + bypass_mode: true # Bypass mode (recommended) + loss_type: ppo_clip # Default: PPO clipped objective ``` Monitor `rollout_corr/kl`, `rollout_corr/log_ppl_abs_diff`, `rollout_corr/chi2_token` to assess off-policy gap. @@ -732,27 +742,25 @@ This workflow uses bypass mode for efficiency. 2. **Enable rejection sampling** if you see high outlier fractions: ```yaml - algorithm: - rollout_correction: - rollout_is: null - rollout_rs: sequence # or "geometric" for higher sensitivity - rollout_rs_threshold: 2.0 - bypass_mode: true # Bypass mode - loss_type: ppo_clip # or "reinforce" for explicit IS weights + rollout_correction: + rollout_is: null + rollout_rs: sequence # or "geometric" for higher sensitivity + rollout_rs_threshold: 2.0 + bypass_mode: true # Bypass mode + loss_type: ppo_clip # or "reinforce" for explicit IS weights ``` This excludes outliers from training without modifying gradients. 3. **Enable full IS correction** (with REINFORCE loss) once comfortable with metrics: ```yaml - algorithm: - rollout_correction: - rollout_is: sequence # Recommended: unbiased, suitable for most cases - rollout_is_threshold: 2.0 - rollout_rs: sequence # or "geometric" for more aggressive filtering - rollout_rs_threshold: 2.0 - bypass_mode: true # Bypass mode - loss_type: reinforce # REINFORCE with explicit IS weights + rollout_correction: + rollout_is: sequence # Recommended: unbiased, suitable for most cases + rollout_is_threshold: 2.0 + rollout_rs: sequence # or "geometric" for more aggressive filtering + rollout_rs_threshold: 2.0 + bypass_mode: true # Bypass mode + loss_type: reinforce # REINFORCE with explicit IS weights ``` **Benefits of bypass mode:** @@ -779,6 +787,12 @@ actor_rollout_ref: calculate_log_probs: true # Required! ``` +### Additional Configurations for Bypass Mode + +- Set `actor_rollout_ref.actor.use_rollout_log_probs: true` +- Set `actor_rollout_ref.actor.policy_loss.loss_mode: bypass_mode` +- Set rollout correction config via `actor_rollout_ref.actor.policy_loss.rollout_correction` + ### Metrics All metrics are prefixed with `rollout_corr/` in logs. For example, `rollout_is_mean` appears as `rollout_corr/rollout_is_mean`. @@ -873,7 +887,7 @@ These metrics cover both: In bypass/pure IS mode, metrics measure the drift between π_θ and π_rollout directly. -- **`training_ppl`**: Perplexity of training reference policy (π*old in decoupled mode, π*θ in bypass/pure IS mode) +- **`training_ppl`**: Perplexity of training reference policy (π_old in decoupled mode, π_θ in bypass/pure IS mode) - **Formula**: `exp(-mean(log_probs))` - Lower values indicate higher model confidence @@ -1133,13 +1147,12 @@ algorithm: ### Example 6: Bypass Mode with REINFORCE ```yaml -algorithm: - rollout_correction: - rollout_is: sequence # Explicit IS correction in loss - rollout_is_threshold: 2.0 - rollout_rs: null # Optional: can add rejection sampling - bypass_mode: true - loss_type: reinforce # REINFORCE with explicit IS weights +rollout_correction: + rollout_is: sequence # Explicit IS correction in loss + rollout_is_threshold: 2.0 + rollout_rs: null # Optional: can add rejection sampling + bypass_mode: true + loss_type: reinforce # REINFORCE with explicit IS weights ``` **No PPO clipping, pure policy gradient with IS correction** @@ -1147,14 +1160,13 @@ algorithm: ### Example 7: Bypass Mode with PPO-clip + Rejection Sampling ```yaml -algorithm: - rollout_correction: - rollout_is: sequence # Computed for metrics - rollout_is_threshold: 2.0 - rollout_rs: seq_max_k2 # Sequence max χ²/2 guard - rollout_rs_threshold: 2.5 - bypass_mode: true - loss_type: ppo_clip # PPO clipped objective (IS handled by ratio) +rollout_correction: + rollout_is: sequence # Computed for metrics + rollout_is_threshold: 2.0 + rollout_rs: seq_max_k2 # Sequence max χ²/2 guard + rollout_rs_threshold: 2.5 + bypass_mode: true + loss_type: ppo_clip # PPO clipped objective (IS handled by ratio) ``` **PPO clipping with rejection sampling. IS handled by PPO ratio (no explicit IS weights).** @@ -1281,7 +1293,7 @@ Run the test suite to verify everything works: ```bash # Basic unit tests -python test_rollout_corr.py +python tests/trainer/ppo/test_rollout_corr.py # Integration tests (if pytest is available) pytest tests/trainer/ppo/test_rollout_corr_integration.py -v @@ -1309,5 +1321,4 @@ Rollout Correction provides a unified framework for handling general off-policy ## References - **[Mathematical Formulations](rollout_corr_math.md)** - Detailed mathematical theory and derivations for all rollout correction methods -- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://richardli.xyz/rl-collapse) (see Blog Series above for parts 1-3) - [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) diff --git a/docs/algo/rollout_corr_math.md b/docs/algo/rollout_corr_math.md index b0b0c13a29c..7118c1ddb49 100644 --- a/docs/algo/rollout_corr_math.md +++ b/docs/algo/rollout_corr_math.md @@ -23,6 +23,14 @@ month = sep, url = {https://richardli.xyz/rl-collapse} } + + +@article{li2025trust, + title={Trust Region Masking for Long-Horizon LLM Reinforcement Learning}, + author={Li, Yingru and Liu, Jiacai and Xu, Jiawei and Tong, Yuxuan and Li, Ziniu and Liu, Qian and Wang, Baoxiang}, + journal={arXiv preprint arXiv:2512.23075}, + year={2025} +} ``` ### Blog Series @@ -31,6 +39,7 @@ - [Part 1: Why Mismatch Breaks LLM-RL](https://richardli.xyz/rl-collapse-1) (analytical framework using TV distance for bias and χ²-divergence for variance) - [Part 2: The Gradient Estimator Trials](https://richardli.xyz/rl-collapse-2) (token-level vs sequence-level correction bias-variance tradeoff) - [Part 3: When Math Meets Reality—Toxic Tails and Length Traps](https://richardli.xyz/rl-collapse-3) (why rejection over clipping, and geometric-level RS) +- Latest Paper: https://arxiv.org/abs/2512.23075 ## Abstract @@ -400,7 +409,7 @@ rollout_rs = "token_k1" # Optional: rejection sampling (ratio bounds) - Lower variance than sequence-level (product of ratios bounded individually) - **Bias-variance tradeoff**: Token-level correction has $O(T^2 \Delta_{\max})$ bias where $T$ is sequence length and $\Delta_{\max}$ is maximum per-token policy divergence. This bias becomes significant when the rollout policy deviates substantially from the training policy. Sequence-level correction is unbiased but has higher variance. - Typical threshold: 1.5 - 5.0 -- Optional batch normalization (§3.6): Normalizes over all token weights to ensure $\mathbb{E}[\tilde{w}_t] = 1$ (reduces variance) +- Optional batch normalization [§3.4](rollout_corr_math.md#34-batch-normalization): Normalizes over all token weights to ensure $\mathbb{E}[\tilde{w}_t] = 1$ (reduces variance) - **When to use**: Token-level works well when rollout policy stays within the trust region of training policy. When mismatch is significant, the bias becomes intolerable and sequence-level correction is preferred. **Loss function (REINFORCE + Token IS):** @@ -429,7 +438,7 @@ rollout_rs = "seq_sum_k1" # Optional: rejection sampling - Multiplicative aggregation across sequence - More sensitive to outliers than token-level - Typical threshold: 2.0 - 10.0 -- Optional batch normalization (§3.6): Normalizes over sequence means (one weight per sequence) +- Optional batch normalization [§3.4](rollout_corr_math.md#34-batch-normalization): Normalizes over sequence means (one weight per sequence) **Terminology Note:** - **Seq-TIS (Sequence-Level Truncated IS)**: Clips the sequence ratio $\rho(\tau) \to \min(\rho(\tau), C)$. Maximizes information efficiency by extracting signal from all samples. Best for clean data with moderate mismatch. @@ -590,7 +599,7 @@ $$ \hat{g}_{\text{k3-rs-token-tis}}(y) = \underbrace{\mathbb{I}\left( K3_{\text{seq}} \le C_{\text{k3}} \right)}_{\text{K3 Filter}} \cdot \prod_t \min(\rho_t, C) \cdot f(y) $$ -This is implemented by combining `rollout_rs="k3"` with `rollout_is="token"`. +This is implemented by combining `rollout_rs="seq_mean_k3"` with `rollout_is="token"`. --- @@ -692,7 +701,7 @@ rollout_rs_threshold = 0.01 | `decoupled_token_is()` | Token-TIS | Decoupled PPO | Per-token IS weights | | `decoupled_seq_is()` | Seq-TIS | Decoupled PPO | Sequence-level IS weights | | `decoupled_seq_is_rs()` | Seq-MIS | Decoupled PPO | Sequence IS + sequence RS | -| `decoupled_geo_rs()` | Geo-RS | Decoupled PPO | Geometric RS + seq\_max\_k2 guard | +| `decoupled_geo_rs()` | Geo-RS | Decoupled PPO | Geometric RS | | `decoupled_geo_rs_token_tis()` | Geo-RS-Token-TIS | Decoupled PPO | Geometric filter + token IS | | **K3 KL Estimator** (more stable for small KL values) | | `decoupled_k3_rs()` | K3-RS | Decoupled PPO | K3 rejection, no IS weights | @@ -950,5 +959,3 @@ These estimators define **how IS weights and rejection masks are computed**. The - **Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017).** "Proximal policy optimization algorithms." *arXiv preprint arXiv:1707.06347.* https://arxiv.org/abs/1707.06347 - **Hilton, J., Cobbe, K., & Schulman, J. (2021).** "Batch size-invariance for policy optimization." *arXiv preprint arXiv:2110.00641.* https://arxiv.org/abs/2110.00641 - Introduced decoupled PPO: separating proximal policy (for controlling policy update size) from behavior policy (for off-policy correction) to achieve batch size invariance -- **Liu, J., Li, Y., et al. (2025).** "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" - - Blog post: https://richardli.xyz/rl-collapse (see Blog Series above for parts 1-3) diff --git a/docs/ascend_tutorial/ascend_ci_guide_zh.rst b/docs/ascend_tutorial/ascend_ci_guide_zh.rst index 8dba3f12a86..5e0e8e91a2d 100644 --- a/docs/ascend_tutorial/ascend_ci_guide_zh.rst +++ b/docs/ascend_tutorial/ascend_ci_guide_zh.rst @@ -12,10 +12,10 @@ NPU 相关的工作流主要包括: * 以 ``_ascend.yml`` 结尾的文件:运行针对 Ascend NPU 的端到端测试或专项测试。 添加新用例指南 -============= +----------------------------------- 1. 数据集与权重 -------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 流水机器上的权重与绝对路径: +---------------------------------------+-------------------------------------------------------------------+ @@ -56,7 +56,7 @@ NPU 相关的工作流主要包括: 2. 工作流 YAML 模板 -------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 如需新增一个工作流,可参考以下模板创建 ``.github/workflows/your_yml_ascend.yml`` 文件。 @@ -140,31 +140,31 @@ NPU 相关的工作流主要包括: 3. 添加单元测试 ---------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 步骤: -1. 在 ``tests/`` 目录下创建或修改单元测试文件(例如 ``test_xxx.py``)。 -2. 若测试文件路径未被 ``npu_unit_test.yml`` 中的 ``--ignore-glob`` 规则排除,则会在以下命令中自动执行: +(1) 在 ``tests/`` 目录下创建或修改单元测试文件(例如 ``test_xxx.py``)。 +(2) 若测试文件路径未被 ``npu_unit_test.yml`` 中的 ``--ignore-glob`` 规则排除,则会在以下命令中自动执行: .. code-block:: yaml pytest -s -x --ignore-glob="xxx" --ignore-glob="xxx" tests/ -3. 若测试路径在 ``--ignore-glob`` 排除范围内,需在 ``npu_unit_test.yml`` 中新增一个 step 来显式运行该测试。 -4. 如新增一批相关用例,建议单独创建专门的工作流文件以保持清晰。 +(3) 若测试路径在 ``--ignore-glob`` 排除范围内,需在 ``npu_unit_test.yml`` 中新增一个 step 来显式运行该测试。 +(4) 如新增一批相关用例,建议单独创建专门的工作流文件以保持清晰。 -3. 添加端到端测试脚本 ---------------------- +4. 添加端到端测试脚本 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 步骤: -1. 在 ``tests/special_npu/`` 目录下创建端到端测试脚本。 -2. 在 ``.github/workflows/`` 目录中找到功能最接近的以 ``_ascend.yml`` 结尾的工作流文件,在其中添加一个 step 调用该脚本。 -3. 若测试场景独立或较复杂,可考虑单独创建新的工作流文件。 +(1) 在 ``tests/special_npu/`` 目录下创建端到端测试脚本。 +(2) 在 ``.github/workflows/`` 目录中找到功能最接近的以 ``_ascend.yml`` 结尾的工作流文件,在其中添加一个 step 调用该脚本。 +(3) 若测试场景独立或较复杂,可考虑单独创建新的工作流文件。 -4. 测试策略建议 ---------------- +5. 测试策略建议 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * **单元测试**:覆盖核心函数、类与方法,确保逻辑正确。 * **集成/端到端测试**:覆盖典型训练、推理 pipeline,验证多模块协同与硬件适配。 diff --git a/docs/ascend_tutorial/ascend_consistency.rst b/docs/ascend_tutorial/ascend_consistency.rst index 20aab3c7057..72c79a37318 100644 --- a/docs/ascend_tutorial/ascend_consistency.rst +++ b/docs/ascend_tutorial/ascend_consistency.rst @@ -1,4 +1,4 @@ -Align the Inference results of the verl and vLLM frameworks on Ascend devices(zh) +推理一致性指导 ==================================== 在昇腾设备上对齐verl和vLLM两个框架下的推理结果。 diff --git a/docs/ascend_tutorial/ascend_profiling_en.rst b/docs/ascend_tutorial/ascend_profiling_en.rst index 4f43544f8f5..f201acc2e37 100644 --- a/docs/ascend_tutorial/ascend_profiling_en.rst +++ b/docs/ascend_tutorial/ascend_profiling_en.rst @@ -1,4 +1,4 @@ -Performance data collection based on FSDP or MindSpeed(Megatron) on Ascend devices(en) +Profiling Data Collection Guide ========================================================================================== Last updated: 12/20/2025. diff --git a/docs/ascend_tutorial/ascend_profiling_zh.rst b/docs/ascend_tutorial/ascend_profiling_zh.rst index 079a8d060e8..8d28be0d59d 100644 --- a/docs/ascend_tutorial/ascend_profiling_zh.rst +++ b/docs/ascend_tutorial/ascend_profiling_zh.rst @@ -1,9 +1,6 @@ -Performance data collection based on FSDP or MindSpeed(Megatron) on Ascend devices(zh) +Profiling采集指导 ================================================================================== -在昇腾设备上基于 FSDP 或 MindSpeed (Megatron) 后端进行性能数据采集 ----------------------------------------------------------------- - Last updated: 12/20/2025. 这是一份在昇腾设备上基于FSDP或MindSpeed(Megatron)后端,使用GRPO或DAPO算法进行数据采集的教程。 diff --git a/docs/ascend_tutorial/ascend_quick_start.rst b/docs/ascend_tutorial/ascend_quick_start.rst index 1fa607befe4..753b78145aa 100644 --- a/docs/ascend_tutorial/ascend_quick_start.rst +++ b/docs/ascend_tutorial/ascend_quick_start.rst @@ -1,7 +1,7 @@ Ascend Quickstart =================================== -Last updated: 12/11/2025. +Last updated: 2/13/2026. 我们在 verl 上增加对华为昇腾设备的支持。 @@ -27,11 +27,11 @@ Atlas 800T A3 ----------------------------------- -DockerFile镜像构建 & 使用 +DockerFile镜像构建 & 获取 & 使用 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -如需要通过 DockerFile 构建镜像,或希望使用基于 verl 构建的镜像,请参考 `文档 `_ 。 - +如需要通过 DockerFile 构建镜像,或希望使用基于 verl 构建的镜像,请参考 `文档 `_ +如果想直接获取镜像,请前往`quay.io/ascend/verl `_ 进行获取,镜像中已包含基础环境和依赖软件包。 安装基础环境 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -67,19 +67,20 @@ DockerFile镜像构建 & 使用 +---------------+----------------------+ | triton-ascend | == 3.2.0rc4 | +---------------+----------------------+ - | transformers | latest release | + | transformers | == 4.57.6 | +---------------+----------------------+ - + + tips: verl is not support transformers 5.0.0 or higher 安装指令: - + .. code-block:: bash - + # 安装torchvision,版本需要和torch匹配 pip install torchvision==0.22.1 - + # 清理环境上可能存在的历史triton/triton-ascend软件包残留 pip uninstall -y triton triton-ascend - + # 安装triton-ascend,不需要单独安装triton pip install triton-ascend==3.2.0rc4 @@ -115,30 +116,30 @@ DockerFile镜像构建 & 使用 MindSpeed 源码安装指令: .. code-block:: bash - + # 下载 MindSpeed,切换到指定commit-id,并下载 Megatron-LM git clone https://gitcode.com/Ascend/MindSpeed.git cd MindSpeed && git checkout f2b0977e && cd .. git clone --depth 1 --branch core_v0.12.1 https://github.com/NVIDIA/Megatron-LM.git - + # 安装 MindSpeed & Megatron pip install -e MindSpeed - + # 将 Megatron-LM 源码路径配置到 PYTHONPATH 环境变量中 export PYTHONPATH=$PYTHONPATH:"$(pwd)/Megatron-LM" - + # (可选)如希望 shell 关闭,或系统重启后,PYTHONPATH 环境变量仍然生效,建议将它添加到 .bashrc 配置文件中 echo "export PYTHONPATH=$PYTHONPATH:\"$(pwd)/Megatron-LM\"" >> ~/.bashrc - + # 安装 mbridge pip install mbridge MindSpeed 对应 Megatron-LM 后端使用场景,使用方式如下: 1. 使能 verl worker 模型 ``strategy`` 配置为 ``megatron`` ,例如 ``actor_rollout_ref.actor.strategy=megatron``。 - + 2. MindSpeed 自定义入参可通过 ``override_transformer_config`` 参数传入,例如对 actor 模型开启 FA 特性可使用 ``+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True``。 - + 3. 更多特性信息可参考 `MindSpeed & verl 文档 `_ 。 @@ -163,7 +164,7 @@ verl 中昇腾暂不支持生态库如下: +---------------+----------------+ | liger-kernel | not supported | +---------------+----------------+ - + 1. 不支持通过 flash_attn 使能 flash attention 加速,支持通过 transformers 使用。 2. 不支持 liger-kernel 使能。 @@ -175,17 +176,17 @@ verl 中昇腾暂不支持生态库如下: 1.下载数据集并将数据集预处理为parquet格式,以便包含计算RL奖励所需的必要字段 .. code-block:: bash - + python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k 2.执行训练 .. code-block:: bash - + set -x - + export VLLM_ATTENTION_BACKEND=XFORMERS - + python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ diff --git a/docs/ascend_tutorial/ascend_sglang_quick_start.rst b/docs/ascend_tutorial/ascend_sglang_quick_start.rst index 8b1661cbbe4..44a7f0c6613 100644 --- a/docs/ascend_tutorial/ascend_sglang_quick_start.rst +++ b/docs/ascend_tutorial/ascend_sglang_quick_start.rst @@ -76,7 +76,8 @@ Atlas 800T A3 git clone https://github.com/volcengine/verl.git # Make sure you have activated verl conda env # NPU_DEVICE=A3 or A2 depends on your device - NPU_DEVICE=A3 bash verl/scripts/install_sglang_mcore_npu.sh + # USE_MEGATRON=1 if you need to install megatron backend + NPU_DEVICE=A3 USE_MEGATRON=1 bash verl/scripts/install_sglang_mcore_npu.sh **4. 安装verl** diff --git a/docs/ascend_tutorial/dockerfile_build_guidance.rst b/docs/ascend_tutorial/dockerfile_build_guidance.rst index e9624d7a6d5..eab82b7365c 100644 --- a/docs/ascend_tutorial/dockerfile_build_guidance.rst +++ b/docs/ascend_tutorial/dockerfile_build_guidance.rst @@ -5,6 +5,16 @@ Last updated: 12/4/2025. 我们在verl上增加对华为昇腾镜像构建的支持。 +镜像获取 & 公开镜像地址 +-------------------- + +昇腾在 `quay.io/ascend/verl `_ 中托管每日构建的 A2/A3 镜像,基于上述 Dockerfile 构建。 + +每日构建镜像名格式:verl-{CANN版本}-{NPU设备类型}-{操作系统版本}-{python版本}-latest + +verl release版本镜像名格式:verl-{CANN版本}-{NPU设备类型}-{操作系统版本}-{python版本}-{verl release版本号} + + 镜像硬件支持 ----------------------------------- @@ -66,16 +76,8 @@ A3 8.3.RC1 SGLang `Dockerfile.ascend.sglang_8.3.rc # vLLM docker build -f Dockerfile.ascend_8.3.rc1_a2 -t verl-ascend:8.3.rc1-a2 . # SGLang - docker build -f Dockerfile.ascend_8.3.rc1_a2 -t verl-ascend-sglang:8.3.rc1-a2 . - -公开镜像地址 --------------------- - -昇腾在 `quay.io/ascend/verl `_ 中托管每日构建的 A2/A3 镜像,基于上述 Dockerfile 构建。 - -每日构建镜像名格式:verl-{CANN版本}-{NPU设备类型}-{操作系统版本}-{python版本}-latest + docker build -f Dockerfile.ascend.sglang_8.3.rc1_a2 -t verl-ascend-sglang:8.3.rc1-a2 . -verl release版本镜像名格式:verl-{CANN版本}-{NPU设备类型}-{操作系统版本}-{python版本}-{verl release版本号} 声明 -------------------- diff --git a/docs/ascend_tutorial/examples/ascend_performance_analysis_guide.md b/docs/ascend_tutorial/examples/ascend_performance_analysis_guide.md new file mode 100644 index 00000000000..5921e7de399 --- /dev/null +++ b/docs/ascend_tutorial/examples/ascend_performance_analysis_guide.md @@ -0,0 +1,169 @@ +# Ascend Performance Analysis Guide + +Last updated: 02/24/2026. + +## 背景介绍 + +随着DeepSeek-R1的发布,大模型强化学习(RL)训练受到广泛关注。在昇腾NPU环境下,verl框架已积累了丰富的性能调优经验。本文系统总结了包括性能数据采集与分析在内的方法论,旨在帮助开发者更高效地运用MindStudio工具链,实现强化学习场景下的性能优化。 + +### 强化学习计算流程概述 + +1. **Rollout**:策略(actor)模型基于输入的prompt序列,推理生成回答(response序列) +2. **ref logprob**:基于prompt和生成的response,reference模型计算ref logprob用于KL散度计算 +3. **logprob**:基于prompt和生成的response,actor模型计算logprob用于重要性采样 +4. **reward**:基于prompt和生成的response,奖励模型评估奖励值R_N。 +5. **update**:基于计算得到的R_N、ref logprob、logprob计算优化函数和策略梯度,对actor模型进行更新 + +![rl_data_stream](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/rl_data_stream.png) + +## profilling工具使能 + +### 使能方法 + +使能和配置教程可参考:[verl/docs/ascend_tutorial/ascend_profiling_zh.rst at main · verl-project/verl](https://github.com/verl-project/verl/raw/main/docs/ascend_tutorial/ascend_profiling_zh.rst) + +## 性能分析方法论 + +### 整体性能概览分析 + +#### 1. 长耗时任务与资源空泡分析 + +- **操作**:使用MindStudio Insight加载profiling数据,自动识别不同计算阶段,通过RL页签流水图定位长耗时任务与NPU资源空泡 +- **价值**:快速掌握不同阶段耗时占比 +- **效果展示**: + +![Bubble_analysis](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/Bubble_analysis.png) + +#### 2. 负载均衡分析 + +- **操作**:通过MindStudio Insight直接查看MSTX打点数据,观察Rollout阶段不同DP Rank的负载均衡情况 +- **价值**:快速识别负载不均问题 + +- **效果展示:** + +![Load_Balancing_Analysis](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/Load_Balancing_Analysis.gif) + +#### 3. 集群整体性能分析 + +- **操作**:结合MSTT的rl_analysis功能,生成集群Timeline缩略图,观察各阶段整体耗时 +- **价值**:宏观掌握集群性能瓶颈 +- **操作指南**:[rl_analysis使用文档](https://gitcode.com/Ascend/mstt/raw/pre-research/profiler/msprof_analyze/docs/features/rl_analysis.md) +- **效果展示**: + +![Cluster%20Performance%20Analysis](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/Cluster%20Performance%20Analysis.png) + +### 细粒度分析 + +#### 性能分析 + +- **操作**:可通过 MindStudio Insight Windows 或 Linux 版本加载 Profiling 数据 + +- **价值**:MindStudio Insight 支持分析任务调度效率、算子执行性能、计算资源利用率、集合通信性能等。其 Timeline 视图具备任务拆解与 Overlap 分析功能(**为 MindStudio 独有核心特性,在 NV 及其他竞品中不具备,是 AI 调优的必备工具**),并支持鼠标交互式分析。 + +- **效果展示**: + +![performance%20analysis](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/performance%20analysis.png) + +#### 内存分析 + +##### **通过 Profiling 结合调用栈分析系统内存变化** + +- **操作**:采集数据时开启调用栈和内存视图功能。 +- **价值**:观察框架、CANN内存申请释放情况,可结合调用栈跟踪到前端python代码。 +- **效果展示**:结合调用栈进行内存变化分析。效果如下所示: + +![in-memory%20analytics](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/in-memory%20analytics.gif) + +##### **使用 msleaks 工具进行深层次内存分析** + +- **操作步骤**:参考 [msleaks 工具使用指南](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/devaids/msleaks/atlas_msleaks_0001.html)。 +- **价值**:可以查看框架内存申请总量折线图/内存块图,并直接对应调用栈,可深层次分析框架内存使用情况。 +- **效果展示**: + +![msleaks](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/msleaks.gif) + +## 性能分析案例 + +要做具体的性能分析,profiling要开启**level1**,否则算子的关键信息会缺失。 + +### 1.host bound诊断 + +host bound是指CPU任务量综合大于NPU,导致NPU执行出现空泡的现象。可以通过看Host2Device的同步连线来判断,如果连线都是歪的,那证明这里的set信号早于wait信号,NPU一ready就执行了,那也是device bound: + +![host_bound_1](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/host_bound_1.png) + +如果确诊为host bound,那么我们可以打开CPU侧,找出各算子的下发耗时。注意找的时候需要找出所有CPU耗时的累加值,而不能找单层,因为首次调用的耗时是很长的。例如下图的GmmSwigluQuant,CPU上首次调用需要1ms,后续每次只需要200us。 + +![host_bound_2](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/host_bound_2.png) + +此时有的算子在负重前行,有的算子拖了后腿,后者多于了前者。我们优先**找出来host耗时大于device的top算子,这些算子是拖后腿的**,可以交予算子团队重点分析。 + +### 2.组网合理性分析 + +有的时候,模型组网没有按照最高效的方式来,这一点在profiling中是非常易于识别的,下面会介绍一下分析思路并给出例子。 + +通常来讲,LLM中的大的热点算子是Attention和FFN中的矩阵乘计算,二者加起来在prefill下可能达到计算耗时的70%+,decode下可能达到50%+。如果整体的耗时比例不符合预期,或者profiling中出现了一些新面孔,或者拼接类算子太多了,这都值得我们去分析一下模型组网,是不是使用算子的方式错了?尤其是拼接类算子,是值得我们逐一分析的。 + +对于slice/split/concat这样的拼接类算子,还有transpose/cast这种转换算子,他们的存在往往是前后算子不直接配套造成的。如果前一个算子可以直接对输出做好尾处理,往往可以节省一个算子的启动开销和一次冗余读写。但这样的改变不一定符合算子的基本设计原则。 + +举一个正例,对于某次Matmul的输出shape为[m, n0 + n1],在这后面我们接了两个slice,输入均为这个[m, n0 + n1]的tensor,输出分别为[m, n0]和[m, n1]。第一个优化的思路是将两个slice改为一个split,这样耗时可以基本减半,[m, n0 + n1]的显存也可以尽早释放。进一步优化的思路是将矩阵乘的权重从[k, n0 + n1]分割为[k, n0]和[k, n1],将原来的矩阵乘任务分成两个(前提是这两个的耗时加起来不比之前的劣化太多,分核策略不能出问题),从而彻底消除这个slice/split操作。 + +![network_1](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/network_1.png) + +举一个反例,Rmsnorm(fp16)+Cast(fp16->fp32)+Matmul(fp32),Rmsnorm虽然输入输出都是fp16,但考虑到累加运算的精度,内部是fp32做计算的。如果将Cast融到Rmsnorm内,本就内部使用fp32做计算的Rmsnorm就可以省去一个末尾fp32->fp16的cast,加上我们干掉的Cast,总共节省两个cast的同时避免了一次精度丢失。虽然这样看起来精度性能双收了,但fp16进,fp32出的Rmsnorm是反原则的(核心输入和输出需要是同数据类型),除非我们能在广大开源模型中频繁找到这样的结构,证明它的普适性,否则算子团队是不允许做这样的算子的。 + +![network_2](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/network_2.png) + +### 3.算子性能初诊 + +需要利用`".\ASCEND_PROFILER_OUTPUT\operator_details.csv"`来做分析,从而判断算子识否有性能问题。 + +Profiling工具会统计这些流水线在不同核上的平均繁忙时间(xxx_time),与最慢核的完整kernel耗时(task_duration)做除法,得到流水线利用率(xxx_ratio)。这些流水线之间虽然互有依赖,且搬运类流水线会互抢带宽,但算子只要设计得当,是可以做到互相掩盖的。因此我们可以初步认为,**当算子的执行耗时大到一定程度上,算子应当在某一条流水线上形成bound**,即利用率要高到一定程度。经验上,在单算子耗时达到50μ时,就可以认为算子应当在bound流水线上,达成80%+的占用率了。 + +以下图为例,第一行是一个FA算子,第二行是一个Matmul算子,FA在vec流水线上达到了88.1%的利用率,Matmul算子在mac流水线上达到了89.8%的利用率,他们的性能可以认为是合格的。 + +![Operator%20performance](https://github.com/chengminhua/verl_data/raw/main/MindStudio_Insight_use/Operator%20performance.png) + + + +### 4.亲和shape调整 + +对于一个模型而言,超参是我们控制不了的,但我们可以控制并发度、权重格式、切分策略等因素来迎合算子,使其发挥出最大的性能,这一节主要从算子搬运效率和负载均衡两个方面出发,讨论模型侧值得尝试的调整方向。 + +#### 4.1 搬运效率亲和的shape + +mte2是一个自身效率严重受shape影响的流水线。要想让mte2保证最大搬运效率,我们需要保障如下两个条件至少满足其一: + +**(1)被搬运的矩阵使用nz作为format(最优) +(2)被搬运的矩阵的尾轴512B对齐,且不为16KB的整数倍(近似最优)** + +对于权重矩阵来说,推理阶段尤其是decode,我们通常满足(1),训练阶段我们通常满足(2)。**如果我们做不到(1),我们就要迎合(2)**。典型的手段有: + +1,如果没达成B的矩阵的首轴是亲和的而尾轴不亲和,那么对它做transpose +2,调整TP切分策略,避免出现不亲和的尾轴 + +#### 4.2 负载均衡亲和的shape + +在算子shape不大时,受制于算子语义,我们有可能不能把所有核都利用起来,或者即使开满核,负载均衡却很差。这一小节主要是对decode阶段的小shape做分析。 + +首先,我们明确出当前NPU卡是多少核的,如果不清楚,跑出来的profiling里都是20,40这样的数,就说明是20核,反之是24核。这里我的24核其实是代表了一个cube和两个vector组成的小组,我们可以认为是一个cube作为主核,带了两个vector作为从核。如果一个算子是纯vector算子,那么就不再有组的概念,40或48个vector核会作为主核直接独立去拿逻辑任务。 + +对于LLM中的vector算子,它的一种常见分核策略有可能是分在最高维,也就是batch维,常见于对低维(也叫尾轴)有规约操作的norm类、动态量化类等算子;另一种是整体拍平,允许算子切分的非常细的算子,如elementwse算子。对于第一种,我们就可以在模型侧关注它的负载均衡问题。例如我们打48batch,而硬件却是个40个vector核,那这40个核会循环2次,第二次有多数的核会无事可做,这个batch数就可以认为是不友好的。如果将batch打到64或80,性能可以预见会是无损的。同样的情况下,如果是48核的卡,那我们可以认为这就是个非常友好的batch数。 + +对于cube类算子,它常见的分核策略是以base快去切分M和N(K轴是累加轴,对它分核会引入确定性问题)。最常见的分块是baseM=128,baseN=256。在decode阶段,我们的耗时基本可以看做都是在搬权重,这是因为激活的M极小,M方向大概率只分了一块,那么右矩阵就只需要搬一次。所以我们在M≤128的范围内可以尽情提高M,对性能都基本是无损的,如果M大于128,可以认为(128, 256]是下一个性能分档。 +除了M外,N轴切分的任务也影响算子亲和性,以deepseekR1中的MLA预处理为例,它会使用同一个激活(shape为[batch_size, 7168])与两个权重做矩阵乘(shape为[7168, 1536]和[7168, 576])。在batch_size打不大的情况下,即使baseN缩短为128,N轴都不能用满核数,所以此时这两个矩阵乘各自的耗时,会约等于将他们权重N轴拼起来乘(shape为[7168, 2112])的矩阵乘的耗时。如果仅考虑模型竞争力,我们更希望对这两个权重做合并,否则两个小的矩阵乘带宽利用率都会非常差。 + +对于Attention算子,它常见的分核策略是q_seqlen、batch_size和kv_headnum。增量阶段q_seqlen会以MTP和GQA倍数做合并,但是通常也不会大过128,划分不出第二个任务,那么并行度基本就是batch_size * kv_headnum。 + +总的来说,我们可以依据shape信息和算子类别,对算子是否有负载均衡问题作出识别,从而对我们切分策略选择,最高吞吐量的batch策略作出预判。 + + + + + + + + + + + diff --git a/docs/ascend_tutorial/examples/ascend_retool_best_pratice.rst b/docs/ascend_tutorial/examples/ascend_retool_best_pratice.rst index 6002b32bf43..4d003fc3370 100644 --- a/docs/ascend_tutorial/examples/ascend_retool_best_pratice.rst +++ b/docs/ascend_tutorial/examples/ascend_retool_best_pratice.rst @@ -1,7 +1,7 @@ Ascend Retool Best Practice =================================== -Last updated: 02/10/2026. +Last updated: 03/01/2026. 引言 ---------------------------------- @@ -100,6 +100,7 @@ https://github.com/bytedance/SandboxFusion .. code-block:: bash + cd SandboxFusion conda create -n sandbox -y python=3.11 conda activate sandbox pip install poetry @@ -135,7 +136,7 @@ https://github.com/bytedance/SandboxFusion dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024 #aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025 - model_path=$DATA_ROOT/dataset/checkpoint/multiturn-sft-qwen-2.5-7b-instruct/ global_step_372/huggingface + model_path=$DATA_ROOT/dataset/checkpoint/multiturn-sft-qwen-2.5-7b-instruct/global_step_372/huggingface train_files="['$dapo_math_17k']" test_files="['$aime_2024']" diff --git a/docs/ascend_tutorial/examples/ascend_sglang_best_practices.rst b/docs/ascend_tutorial/examples/ascend_sglang_best_practices.rst index a2bcaae9958..bb6259a9cd1 100644 --- a/docs/ascend_tutorial/examples/ascend_sglang_best_practices.rst +++ b/docs/ascend_tutorial/examples/ascend_sglang_best_practices.rst @@ -43,18 +43,11 @@ SGLang 是当前主流的高性能开源推理引擎, 昇腾已经全面原生 ^^^^^^^^^^^ **下载模型权重** ---local-dir: 模型保存路径 - -.. code-block:: bash - - export HF_ENDPOINT=https://hf-mirror.com - hf download --resume-download Qwen/Qwen3-30B-A3B --local-dir /path/to/local_dir +Qwen3-30B: https://huggingface.co/Qwen/Qwen3-30B-A3B **下载数据集** -.. code-block:: bash - - git clone https://www.modelscope.cn/datasets/AI-ModelScope/DAPO-Math-17k.git +DAPO-Math-17k: https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k **HuggingFace To Megatron权重转换(可选)** diff --git a/docs/ascend_tutorial/examples/gspo_optimization_practice.md b/docs/ascend_tutorial/examples/gspo_optimization_practice.md index e943fcdbfff..92952141a91 100644 --- a/docs/ascend_tutorial/examples/gspo_optimization_practice.md +++ b/docs/ascend_tutorial/examples/gspo_optimization_practice.md @@ -1,233 +1,415 @@ -## NPU Qwen3-32B GSPO Optimization Practice - -Last updated: 01/27/2026. - -本文章对应脚本地址:[qwen3_32b_gspo_npu](https://github.com/volcengine/verl/blob/main/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh) - -### 算法适配 - -GSPO通过将优化颗粒度从**token级**提升到**sequence级**,规避了GRPO会遇到的**方差急剧增大**导致训练不稳定的情况,增加了训练的稳定性,同时该算法也在一定程度上提升了算法的收敛速度。 - -想要成功在verl仓库中成功调用到GSPO算法,需要进行如下的必要配置 - -~~~python -# 核心算法配置 -algorithm.adv_estimator=grpo \ # 使用GRPO优势估计器 -algorithm.use_kl_in_reward=False \ # 不在奖励中添加KL惩罚 -# GSPO策略损失模式 -actor_rollout_ref.actor.policy_loss.loss_mode=gspo \ # 启用GSPO策略损失 -# 极小裁剪范围(GSPO特色) -actor_rollout_ref.actor.clip_ratio_low=0.0003 \ # 裁剪下界,论文推荐值 -actor_rollout_ref.actor.clip_ratio_high=0.0004 \ # 裁剪上界,论文推荐值 -# KL配置(GSPO不使用KL loss) -actor_rollout_ref.actor.use_kl_loss=False \ # 禁用KL损失 -actor_rollout_ref.actor.kl_loss_coef=0.0 \ # KL损失系数设为0 -# 序列级损失聚合模式(GSPO核心) -actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \ # 序列级平均,GSPO论文推荐 -# 批次配置 -actor_rollout_ref.rollout.n=16 \ # 每个prompt生成16个响应(组采样) -~~~ - -一般选择入口函数为`verl.trainer.main_ppo` - -### 性能调优 - -优化从训练、推理、调度和其他四个方面入手。 - -#### 训练 - -##### 动态bsz - -~~~bash -actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) -infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) -~~~ - -**这个优化点主要调整上面这两个参数,不过需要注意这两个参数调整的太大会导致OOM** - -**主要调整**`actor_ppo_max_token_len`,调大了会降低训练的耗时,调整`infer_ppo_max_token_len`没有明显的收益,可以不动 - -**这两个参数的作用介绍如下:** - -**这两个参数用于控制动态批处理(dynamic batch size)模式下每个GPU处理的最大token数量** - -- **`actor_ppo_max_token_len`**: Actor模型在PPO更新(前向+反向传播)时每个GPU能处理的最大token数 -- **`infer_ppo_max_token_len`**: 推理阶段(Reference policy和Rollout)计算log概率时每个GPU能处理的最大token数 - -#### 推理 - -##### ACLgraph+FULL_DECODE_ONLY - -推理算子下发方面的优化,平均能有`15%~20%`左右的性能收益。 - -先看单开**ACLgraph**,如下: - -~~~bash -# 开启ACLgraph+FULL_DECODE_ONLY(注意:当设置此参数为False时,TASK_QUEUE_ENABLE必须设置为1,不然会报错) -actor_rollout_ref.rollout.enforce_eager=False -actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes='[8,16,32,64,128]' \ -actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode='FULL_DECODE_ONLY' \ -~~~ - -`FULL_DECODE_ONLY`开启成功后有如下输出: - -![FULL_DECODE_ONLY result](https://github.com/wucong25/verl-data/blob/main/ascend_acl_graph.png) - -**`cudagraph_capture_sizes`参数设置指南** - -cudagraph_capture_sizes设置的值对应的是批大小,这里的批大小不是配置里的DP域对应的那个批次大小,这里是相较于vllm来说的批大小,单位为**token** - -默认生成的算法如下,可做参考 - -![cudagraph_capture_sizes](https://github.com/wucong25/verl-data/blob/main/ascend_set_cudagraph_sizes.png) - -##### 推理后端切换 - -使用方式:`export VLLM_ATTENTION_BACKEND=XFORMERS` - -![VLLM_ATTENTION_BACKEND](https://github.com/wucong25/verl-data/blob/main/ascend_vllm_attn_backend.png) - -注:需要注意某些后端在一些比较老的vllm-ascend版本内并不支持 - -##### 使能vllm v1版本 - -使用方式:`export VLLM_USE_V1=1` - -可以常开,一般都是正收益。 - -#### 调度 - -##### AIV - -打开方式:设置`export HCCL_OP_EXPANSION_MODE="AIV"` - -HCCL_OP_EXPANSION_MODE环境变量用于配置通信算法的编排展开位置,支持如下取值: - -- AI_CPU:代表通信算法的编排展开位置在Device侧的AI CPU计算单元。 -- AIV:代表通信算法的编排展开位置在Device侧的Vector Core计算单元。 -- HOST:代表通信算法的编排展开位置为Host侧CPU,Device侧根据硬件型号自动选择相应的调度器。 -- HOST_TS:代表通信算法的编排展开位置为Host侧CPU,Host向Device的Task Scheduler下发任务,Device的Task Scheduler进行任务调度执行。 - -下面介绍两种展开机制 - -###### HOST展开 - -image-20260113194257095 - -- 软件栈工作在hostcpu,通信算法展开一个个task -- 每个task调用runtime接口,下发到device的rtsqueue -- STARS从rstqueue上顺序拿取task -- 根据task类型分别调用掉SDMA和RDMA引擎。 - **单算子瓶颈**:hostbound 每个task提交是2~5us,一个通信算子有几百个task,单算子场景不会在device上缓存,下发一个执行一个 - -###### AICpu机制展开 - -image-20260113194333218 - -- host侧不下发一个个task,把通信算子作为一个个kernel,放在通信算子kernel的队列上去。 -- STARS调度kernel队列流上的kernel,把kernel放到AiCPU上去执行。 -- AICPU调用函数(kernel),用一个线程执行kernel 函数,在函数内把通信task展开,把task放到rstqueue上,STARS调用。 -- 降低host和aicpu交互,由几百次降低为一次。 -- task的提交在AICPU上提交,做了提交的部分合并。 - -##### TASK_QUEUE_ENABLE - -**使用方式:**`export TASK_QUEUE_ENABLE=2` - -TASK_QUEUE_ENABLE,下发优化,图模式设置为1(即开启图模式的时候这个要设置为1),非图模式设置为2 - -示意图: - -![ascend task queue](https://github.com/wucong25/verl-data/blob/main/ascend_task_queue2.png) - -##### 绑核优化 - -**使用方式:**`export CPU_AFFINITY_CONF=1` - -详细设置原理可看:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0059.html - -#### 其他 - -以下内容汇总了若干全局环境变量的调优配置。由于这些参数在训练阶段与推理阶段往往都能带来正向收益,且目前尚缺乏足够精细的消融实验来严格区分它们各自对训练或推理的贡献占比,故统一归拢在此,供后续持续监控与进一步拆解分析。 - -##### 使能jemalloc - -使用方式(注意需要先安装jemalloc库):`export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2` - -**安装使用教程:**[MindSpeed-RL/docs/install_guide.md · Ascend/MindSpeed-RL - AtomGit | GitCode](https://gitcode.com/Ascend/MindSpeed-RL/blob/master/docs/install_guide.md#高性能内存库-jemalloc-安装) - -##### 多流复用 - -内存方面有优化 - -使能方式:`export MULTI_STREAM_MEMORY_REUSE=1` - -原理介绍:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0040.html - -##### VLLM_ASCEND_ENABLE_FLASHCOMM - -使用方式:`export VLLM_ASCEND_ENABLE_FLASHCOMM=1` - -启用昇腾 NPU 特有的FLASHCOMM高速通信优化技术 - -地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html - -##### VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE - -使用方式:`export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1` - -启用昇腾 NPU针对大模型推理的稠密计算优化 - -地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html - -##### VLLM_ASCEND_ENABLE_PREFETCH_MLP - -使用方式:`export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1` - -启用 MLP 层的权重预取机制 - -image-20251124173132677 - -##### verl框架参数设置 - -主要是内存方面的一些设置开关(注意,这个里面的优化都或多或少会导致吞吐量有一定程度的劣化) - -~~~bash -# 梯度检查点 (Gradient Checkpointing) -# 作用: 通过重新计算激活值来节省显存,以计算换内存。在前向传播时不保存中间激活值,反向传播时重新计算,可以显著降低显存占用,允许使用更大的batch size。 -actor_rollout_ref.model.enable_gradient_checkpointing=True - -# 参数卸载 (Parameter Offload) -# 作用: 将模型参数卸载到CPU内存,训练时再加载回GPU。 -actor_rollout_ref.actor.fsdp_config.param_offload=${offload} # True -actor_rollout_ref.ref.fsdp_config.param_offload=${offload} # True - -# 优化器状态卸载 (Optimizer Offload) -# 作用: 将优化器状态(如Adam的动量)卸载到CPU。优化器状态通常占用大量显存(对于Adam,每个参数需要额外8字节),卸载可以节省显存。 -actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} # True - -# 释放推理引擎缓存 (Free Cache Engine) -# 作用: 在训练阶段释放推理引擎的KV cache和权重。这是3D-HybridEngine的核心优化,允许在同一GPU上交替进行推理和训练,显著降低显存需求。 -actor_rollout_ref.rollout.free_cache_engine=True - -# 熵计算优化 -# entropy_checkpointing: 在训练时对熵计算启用重计算,降低显存峰值 -# entropy_from_logits_with_chunking: 分块处理logits张量(如2048 tokens一组),避免一次性加载整个[bsz*seq_len, vocab]张量 -actor_rollout_ref.actor.entropy_checkpointing=True -actor_rollout_ref.ref.entropy_checkpointing=True -actor_rollout_ref.actor.entropy_from_logits_with_chunking=True -actor_rollout_ref.ref.entropy_from_logits_with_chunking=True - -# 推理引擎显存配置 -# gpu_memory_utilization: 控制vLLM使用的GPU显存比例(0.90 = 90%) -# enforce_eager=False: 启用CUDA graphs加速推理,但会占用额外显存 -actor_rollout_ref.rollout.gpu_memory_utilization=0.90 -actor_rollout_ref.rollout.enforce_eager=False -~~~ - -### NPU调优参考文章 - -环境变量相关:[环境变量列表-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/apiref/Envvariables/Envir_001.html) - -社区性能调优教程:[性能调优流程-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0001.html) - +# NPU Qwen3-32B GSPO Optimization Practice + +Last updated: 02/26/2026. + +本文章对应脚本地址:[qwen3_32b_gspo_npu](https://github.com/volcengine/verl/blob/main/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh) + +## 算法适配 + +GSPO通过将优化颗粒度从**token级**提升到**sequence级**,规避了GRPO会遇到的**方差急剧增大**导致训练不稳定的情况,增加了训练的稳定性,同时该算法也在一定程度上提升了算法的收敛速度。 + +想要成功在verl仓库中成功调用到GSPO算法,需要进行如下的必要配置 + +~~~python +# 核心算法配置 +algorithm.adv_estimator=grpo \ # 使用GRPO优势估计器 +algorithm.use_kl_in_reward=False \ # 不在奖励中添加KL惩罚 +# GSPO策略损失模式 +actor_rollout_ref.actor.policy_loss.loss_mode=gspo \ # 启用GSPO策略损失 +# 极小裁剪范围(GSPO特色) +actor_rollout_ref.actor.clip_ratio_low=0.0003 \ # 裁剪下界,论文推荐值 +actor_rollout_ref.actor.clip_ratio_high=0.0004 \ # 裁剪上界,论文推荐值 +# KL配置(GSPO不使用KL loss) +actor_rollout_ref.actor.use_kl_loss=False \ # 禁用KL损失 +actor_rollout_ref.actor.kl_loss_coef=0.0 \ # KL损失系数设为0 +# 序列级损失聚合模式(GSPO核心) +actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \ # 序列级平均,GSPO论文推荐 +# 批次配置 +actor_rollout_ref.rollout.n=16 \ # 每个prompt生成16个响应(组采样) +~~~ + +一般选择入口函数为`verl.trainer.main_ppo` + +## 基础环境 + +当前支持Atlas 800T A3 与 Atlas 900 A3 SuperPoD。完成跑完本次最佳实践需要 4台Atlas 800T A3。关键软件版本可以参考:[Ascend Quickstart](https://github.com/volcengine/verl/blob/main/docs/ascend_tutorial/ascend_quick_start.rst) + +### 安装基础环境 + +| software | version | +| ------------ | ---------------------------------------------------------- | +| Python | >= 3.10, <3.12 | +| CANN | == 8.3.RC1 | +| torch | == 2.7.1 | +| torch_npu | == 2.7.1 | +| verl | main分支 commitId=252d76908b903ad8fb6969eb3a5e5f873c95ea2b | +| vllm | v0.11.0 | +| vllm-ascend | v0.11.0-dev | +| transformers | 4.57.3 | + +在本实践中, 我们通过指定 verl 的commit id 以避免引入其他问题 + +~~~bash +cd verl +git checkout 252d76908b903ad8fb6969eb3a5e5f873c95ea2b +# 指定相应的recipe版本 +git submodule update --init --recursive recipe +~~~ + +### 权重获取 + +从Hugging Face库下载对应的模型权重:[Qwen/Qwen3-32B · Hugging Face](https://huggingface.co/Qwen/Qwen3-32B) + +### 数据集准备 + +~~~bash +# 下载math-17k数据集 +git clone https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k + +# 下载AIME_2024测试数据集 +git clone https://huggingface.co/datasets/Maxwell-Jia/AIME_2024 +~~~ + +### jemalloc安装 + +为了确保 Ray 进程能够正常回收内存,需要安装并使能 jemalloc 库进行内存管理。 + +#### Ubuntu 操作系统 + +通过操作系统源安装jemalloc(注意: 要求ubuntu版本>=20.04): + +```shell +sudo apt install libjemalloc2 +``` + +在启动任务前执行如下命令通过环境变量导入jemalloc,需先通过 **find /usr -name libjemalloc.so.2** 确认文件是否存在 : + +```shell +# arm64架构 +export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2 +# x86_64架构 +export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 +``` + +#### OpenEuler 操作系统 + +执行如下命令重操作系统源安装jemalloc + +```shell +yum install jemalloc +``` + +如果上述方法无法正常安装,可以通过源码编译安装 前往jemalloc官网下载最新稳定版本,官网地址:https://github.com/jemalloc/jemalloc/releases/ + +```shell +tar -xvf jemalloc-{version}.tar.bz2 +cd jemalloc-{version} +./configure --prefix=/usr/local +make +make install +``` + +在启动任务前执行如下命令通过环境变量导入jemalloc: + +```shell +#根据实际安装路径设置环境变量,例如安装路径为:/usr/local/lib/libjemalloc.so.2,可通过以下命令来设置环境变量(可通过 find /usr -name libjemalloc.so.2 确认文件是否存在) +export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2 +``` + +### 多机任务拉起 + +针对本实践提供的多机任务,可用下面的脚本拉起 + +~~~bash +pkill -9 python +ray stop --force +rm -rf /tmp/ray + +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export TASK_QUEUE_ENABLE=1 +export HCCL_EXEC_TIMEOUT=3600 +export HCCL_CONNECT_TIMEOUT=3600 +export HCCL_ASYNC_ERROR_HANDLING=0 +export CPU_AFFINITY_CONF=1 +export VLLM_USE_V1=1 +export VLLM_ATTENTION_BACKEND=XFORMERS +export VLLM_ASCEND_ENABLE_FLASHCOMM=1 +export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1 +export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1 +export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2 + +# 修改为当前需要跑的用例路径 +DEFAULT_SH="./run_*.sh" +echo "Use $DEFAULT_SH" + +ulimit -n 32768 +mkdir logs + +NNODES=4 +NPUS_PER_NODE=16 +# 修改为对应主节点IP +MASTER_ADDR="IP FOR MASTER NODE" +# 修改为当前节点的通信网卡 +SOCKET_IFNAME="Your SOCKET IFNAME" +export HCCL_SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE" +export GLOO_SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE" +# 获取当前IP +CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}') +if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then + # 主节点启动 + ray start --head --port 6766 --dashboard-host=$MASTER_ADDR --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}' + + while true; do + ray_status_output=$(ray status) + npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) + npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') + device_count=$((npu_count_int / $NPUS_PER_NODE)) + + # 判断device_count 是否与 NNODES 相等 + if [ "$device_count" -eq "$NNODES" ]; then + echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script." + ray status + bash $DEFAULT_SH + break + else + echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" + sleep 5 + fi + done +else + # 子节点尝试往主节点注册 ray 直到成功 + while true; do + # 尝试连接 ray 集群 + ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP + + # 检查连接是否成功 + ray status + if [ $? -eq 0 ]; then + echo "Successfully connected to the Ray cluster!" + break + else + echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." + sleep 5 + fi + done +fi + +sleep 600 +~~~ + +DEFAULT_SH:修改为训练所用配置 sh 文件路径。在此案例中修改为 [Qwen2.5-32B](https://github.com/volcengine/verl/blob/main/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh) 路径。 + +NNODES 和 NPUS_PER_NODE:修改为使用节点数量和每个节点 NPU 数量。在此案例中分别为4和16。 + +MASTER_ADDR:修改为对应主节点 IP。即所有节点的 MASTER_ADDR 应该相同。 + +SOCKET_IFNAME, HCCL_SOCKET_IFNAME, GLOO_SOCKET_IFNAME: 修改为对应通信网卡,通信网卡可以通过以下命令获取: + +``` +ifconfig |grep "$(hostname -I |awk '{print $1}'|awk -F '.' '{print $0}')" -B 1|awk -F ':' '{print$1}' | head -1 | tail -1 +``` + +## 性能调优 + +优化从训练、推理、调度和其他四个方面入手。 + +### 训练 + +#### 动态bsz + +~~~bash +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +~~~ + +**这个优化点主要调整上面这两个参数,不过需要注意这两个参数调整的太大会导致OOM** + +**主要调整**`actor_ppo_max_token_len`,调大了会降低训练的耗时,调整`infer_ppo_max_token_len`没有明显的收益,可以不动 + +**这两个参数的作用介绍如下:** + +**这两个参数用于控制动态批处理(dynamic batch size)模式下每个GPU处理的最大token数量** + +- **`actor_ppo_max_token_len`**: Actor模型在PPO更新(前向+反向传播)时每个GPU能处理的最大token数 +- **`infer_ppo_max_token_len`**: 推理阶段(Reference policy和Rollout)计算log概率时每个GPU能处理的最大token数 + +### 推理 + +#### ACLgraph+FULL_DECODE_ONLY + +推理算子下发方面的优化,平均能有`15%~20%`左右的性能收益。 + +先看单开**ACLgraph**,如下: + +~~~bash +# 开启ACLgraph+FULL_DECODE_ONLY(注意:当设置此参数为False时,TASK_QUEUE_ENABLE必须设置为1,不然会报错) +actor_rollout_ref.rollout.enforce_eager=False +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes='[8,16,32,64,128]' \ +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode='FULL_DECODE_ONLY' \ +~~~ + +`FULL_DECODE_ONLY`开启成功后有如下输出: + +![FULL_DECODE_ONLY result](https://github.com/wucong25/verl-data/blob/main/ascend_acl_graph.png) + +**`cudagraph_capture_sizes`参数设置指南** + +cudagraph_capture_sizes设置的值对应的是批大小,这里的批大小不是配置里的DP域对应的那个批次大小,这里是相较于vllm来说的批大小,单位为**token** + +默认生成的算法如下,可做参考 + +![cudagraph_capture_sizes](https://github.com/wucong25/verl-data/blob/main/ascend_set_cudagraph_sizes.png) + +##### 推理后端切换 + +使用方式:`export VLLM_ATTENTION_BACKEND=XFORMERS` + +![VLLM_ATTENTION_BACKEND](https://github.com/wucong25/verl-data/blob/main/ascend_vllm_attn_backend.png) + +注:需要注意某些后端在一些比较老的vllm-ascend版本内并不支持 + +##### 使能vllm v1版本 + +使用方式:`export VLLM_USE_V1=1` + +可以常开,一般都是正收益。 + +### 调度 + +#### AIV + +打开方式:设置`export HCCL_OP_EXPANSION_MODE="AIV"` + +HCCL_OP_EXPANSION_MODE环境变量用于配置通信算法的编排展开位置,支持如下取值: + +- AI_CPU:代表通信算法的编排展开位置在Device侧的AI CPU计算单元。 +- AIV:代表通信算法的编排展开位置在Device侧的Vector Core计算单元。 +- HOST:代表通信算法的编排展开位置为Host侧CPU,Device侧根据硬件型号自动选择相应的调度器。 +- HOST_TS:代表通信算法的编排展开位置为Host侧CPU,Host向Device的Task Scheduler下发任务,Device的Task Scheduler进行任务调度执行。 + +下面介绍两种展开机制 + +##### HOST展开 + +image-20260113194257095 + +- 软件栈工作在hostcpu,通信算法展开一个个task +- 每个task调用runtime接口,下发到device的rtsqueue +- STARS从rstqueue上顺序拿取task +- 根据task类型分别调用掉SDMA和RDMA引擎。 + **单算子瓶颈**:hostbound 每个task提交是2~5us,一个通信算子有几百个task,单算子场景不会在device上缓存,下发一个执行一个 + +##### AICpu机制展开 + +image-20260113194333218 + +- host侧不下发一个个task,把通信算子作为一个个kernel,放在通信算子kernel的队列上去。 +- STARS调度kernel队列流上的kernel,把kernel放到AiCPU上去执行。 +- AICPU调用函数(kernel),用一个线程执行kernel 函数,在函数内把通信task展开,把task放到rstqueue上,STARS调用。 +- 降低host和aicpu交互,由几百次降低为一次。 +- task的提交在AICPU上提交,做了提交的部分合并。 + +#### TASK_QUEUE_ENABLE + +**使用方式:**`export TASK_QUEUE_ENABLE=2` + +TASK_QUEUE_ENABLE,下发优化,图模式设置为1(即开启图模式的时候这个要设置为1),非图模式设置为2 + +示意图: + +![ascend task queue](https://github.com/wucong25/verl-data/blob/main/ascend_task_queue2.png) + +##### 绑核优化 + +**使用方式:**`export CPU_AFFINITY_CONF=1` + +详细设置原理可看:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0059.html + +### 其他 + +以下内容汇总了若干全局环境变量的调优配置。由于这些参数在训练阶段与推理阶段往往都能带来正向收益,且目前尚缺乏足够精细的消融实验来严格区分它们各自对训练或推理的贡献占比,故统一归拢在此,供后续持续监控与进一步拆解分析。 + +#### 使能jemalloc + +使用方式(注意需要先安装jemalloc库):`export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2` + +**安装使用教程:**[MindSpeed-RL/docs/install_guide.md · Ascend/MindSpeed-RL - AtomGit | GitCode](https://gitcode.com/Ascend/MindSpeed-RL/blob/master/docs/install_guide.md#高性能内存库-jemalloc-安装) + +#### 多流复用 + +内存方面有优化 + +使能方式:`export MULTI_STREAM_MEMORY_REUSE=1` + +原理介绍:https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0040.html + +#### VLLM_ASCEND_ENABLE_FLASHCOMM + +使用方式:`export VLLM_ASCEND_ENABLE_FLASHCOMM=1` + +启用昇腾 NPU 特有的FLASHCOMM高速通信优化技术 + +地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html + +#### VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE + +使用方式:`export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1` + +启用昇腾 NPU针对大模型推理的稠密计算优化 + +地址:https://vllm-ascend.readthedocs.io/zh-cn/latest/user_guide/release_notes.html + +#### VLLM_ASCEND_ENABLE_PREFETCH_MLP + +使用方式:`export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1` + +启用 MLP 层的权重预取机制 + +image-20251124173132677 + +### verl框架参数设置 + +主要是内存方面的一些设置开关(注意,这个里面的优化都或多或少会导致吞吐量有一定程度的劣化) + +~~~bash +# 梯度检查点 (Gradient Checkpointing) +# 作用: 通过重新计算激活值来节省显存,以计算换内存。在前向传播时不保存中间激活值,反向传播时重新计算,可以显著降低显存占用,允许使用更大的batch size。 +actor_rollout_ref.model.enable_gradient_checkpointing=True + +# 参数卸载 (Parameter Offload) +# 作用: 将模型参数卸载到CPU内存,训练时再加载回GPU。 +actor_rollout_ref.actor.fsdp_config.param_offload=${offload} # True +actor_rollout_ref.ref.fsdp_config.param_offload=${offload} # True + +# 优化器状态卸载 (Optimizer Offload) +# 作用: 将优化器状态(如Adam的动量)卸载到CPU。优化器状态通常占用大量显存(对于Adam,每个参数需要额外8字节),卸载可以节省显存。 +actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} # True + +# 释放推理引擎缓存 (Free Cache Engine) +# 作用: 在训练阶段释放推理引擎的KV cache和权重。这是3D-HybridEngine的核心优化,允许在同一GPU上交替进行推理和训练,显著降低显存需求。 +actor_rollout_ref.rollout.free_cache_engine=True + +# 熵计算优化 +# entropy_checkpointing: 在训练时对熵计算启用重计算,降低显存峰值 +# entropy_from_logits_with_chunking: 分块处理logits张量(如2048 tokens一组),避免一次性加载整个[bsz*seq_len, vocab]张量 +actor_rollout_ref.actor.entropy_checkpointing=True +actor_rollout_ref.ref.entropy_checkpointing=True +actor_rollout_ref.actor.entropy_from_logits_with_chunking=True +actor_rollout_ref.ref.entropy_from_logits_with_chunking=True + +# 推理引擎显存配置 +# gpu_memory_utilization: 控制vLLM使用的GPU显存比例(0.90 = 90%) +# enforce_eager=False: 启用CUDA graphs加速推理,但会占用额外显存 +actor_rollout_ref.rollout.gpu_memory_utilization=0.90 +actor_rollout_ref.rollout.enforce_eager=False +~~~ + +## NPU调优参考文章 + +环境变量相关:[环境变量列表-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/apiref/Envvariables/Envir_001.html) + +社区性能调优教程:[性能调优流程-Ascend Extension for PyTorch6.0.0-昇腾社区](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0001.html) + + + diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 9909dd67581..b1053903510 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -130,12 +130,10 @@ Actor/Rollout/Reference Policy use_kl_loss: False # True for GRPO # Rollout Correction (corrects distribution mismatch between rollout and training) rollout_correction: - rollout_is: token # IS weights: token/sequence/null + rollout_is: token # IS weights rollout_is_threshold: 2.0 # Upper threshold for IS weights - rollout_rs: null # Rejection sampling: token/sequence/geometric/null + rollout_rs: null # Rejection sampling rollout_rs_threshold: null # RS upper threshold - rollout_rs_threshold_lower: null # RS lower threshold - rollout_token_veto_threshold: null # Per-token veto (null to disable) use_torch_compile: True # False to disable torch compile kl_loss_coef: 0.001 # for grpo kl_loss_type: low_var_kl # for grpo @@ -534,12 +532,10 @@ Algorithm target_kl: 0.1 # Rollout Correction rollout_correction: - rollout_is: null # IS weights: token/sequence/null + rollout_is: null # IS weights rollout_is_threshold: 2.0 # Upper threshold for IS weights - rollout_rs: null # Rejection sampling: token/sequence/geometric/null + rollout_rs: null # Rejection sampling rollout_rs_threshold: null # RS upper threshold - rollout_rs_threshold_lower: null # RS lower threshold - rollout_token_veto_threshold: null # Per-token veto (null to disable) - ``gamma``: discount factor - ``lam``: Trade-off between bias and variance in the GAE estimator @@ -557,12 +553,10 @@ Algorithm - ``rollout_correction``: Rollout Correction configuration (nested dict). Set to ``null`` to disable. When enabled, contains: - - ``rollout_is``: IS weights aggregation level: ``token``, ``sequence``, or ``null`` to disable IS weights. + - ``rollout_is``: IS weights aggregation level, ``null`` to disable IS weights. - ``rollout_is_threshold``: Upper threshold for IS weights (e.g., 2.0). - - ``rollout_rs``: Rejection sampling mode: ``token``, ``sequence``, ``geometric``, or ``null`` to disable RS. + - ``rollout_rs``: Rejection sampling mode, ``null`` to disable RS. - ``rollout_rs_threshold``: RS upper threshold. - - ``rollout_rs_threshold_lower``: RS lower threshold (null = auto-reciprocal). - - ``rollout_token_veto_threshold``: Per-token veto threshold for catastrophic outliers (null = disabled). Note: Rollout Correction requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``. diff --git a/docs/examples/gsm8k_example.rst b/docs/examples/gsm8k_example.rst index bc56497be64..1f5bdde7a22 100644 --- a/docs/examples/gsm8k_example.rst +++ b/docs/examples/gsm8k_example.rst @@ -75,7 +75,7 @@ model. --------------------------------- We provide a SFT Trainer using PyTorch FSDP in -`fsdp_sft_trainer.py `_. +`sft_trainer.py `_. Users can customize their own SFT script using our FSDP SFT Trainer. @@ -85,7 +85,7 @@ We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft set -x - torchrun -m verl.trainer.fsdp_sft_trainer \ + torchrun -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ data.prompt_key=question \ diff --git a/docs/hybrid_flow.rst b/docs/hybrid_flow.rst index 3aa5a4a97cb..3eb1571ca9a 100644 --- a/docs/hybrid_flow.rst +++ b/docs/hybrid_flow.rst @@ -217,7 +217,7 @@ Important code files in the repository are organized as below: main_ppo.py # the entrypoint for RL training ppo ray_trainer.py # the training loop for RL algorithms such as PPO - fsdp_sft_trainer.py # the SFT trainer with FSDP backend + sft_trainer.py # the SFT trainer with FSDP backend config generation.yaml # configuration template for rollout ppo_trainer.yaml # configuration template for the RL trainer diff --git a/docs/index.rst b/docs/index.rst index 3a42483ebc2..680a446974e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -81,6 +81,7 @@ verl is fast with: algo/rollout_corr.md algo/rollout_corr_math.md algo/otb.md + algo/dppo.md .. toctree:: :maxdepth: 1 @@ -153,6 +154,7 @@ verl is fast with: ascend_tutorial/dockerfile_build_guidance.rst ascend_tutorial/ascend_sglang_quick_start.rst ascend_tutorial/examples/gspo_optimization_practice.md + ascend_tutorial/examples/ascend_performance_analysis_guide.md ascend_tutorial/examples/dapo_multi_model_optimization_practice.md ascend_tutorial/examples/ascend_sglang_best_practices.rst ascend_tutorial/examples/ascend_retool_best_pratice.rst diff --git a/docs/start/quickstart.rst b/docs/start/quickstart.rst index c0be6a6b30b..c72b531668b 100644 --- a/docs/start/quickstart.rst +++ b/docs/start/quickstart.rst @@ -53,7 +53,7 @@ Step 2: Download a model for post-training In this example, we start with the ``Qwen2.5-0.5B-Instruct`` model. -If you want to perform SFT before RL, refer to the :doc:`Complete GSM8K Example<../examples/gsm8k_example>`, the `sft directory `_ and `SFT Trainer `_ for further details. +If you want to perform SFT before RL, refer to the :doc:`Complete GSM8K Example<../examples/gsm8k_example>`, the `sft directory `_ and `SFT Trainer `_ for further details. .. code-block:: bash diff --git a/examples/dppo_trainer/dppo.md b/examples/dppo_trainer/dppo.md new file mode 100644 index 00000000000..36802990bd5 --- /dev/null +++ b/examples/dppo_trainer/dppo.md @@ -0,0 +1,94 @@ +# Divergence Proximal Policy Optimization (DPPO) + + +
+ +## Rethinking the Trust Region in LLM Reinforcement Learning + +[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white )](https://arxiv.org/pdf/2602.04879) +[![Github](https://img.shields.io/badge/Stable_RL-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/sail-sg/Stable-RL) +[![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/QPHutu/status/2019435642539897303) + +
+ + +## ✨Getting started + +1. Prepare the datasets by running [prepare_dapo_data.sh](https://github.com/verl-project/verl-recipe/blob/3490a22a0a3adeb7e4787fe70b1060b642efbae4/dapo/prepare_dapo_data.sh): + +```bash +bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default +``` + +2. Prepare the model: + +```bash +hf download Qwen/Qwen3-30B-A3B-Base --local-dir ${HOME}/verl/models/Qwen3-30B-A3B-Base +``` + +3. Run the script: + +```bash +# run DPPO-Binary-KL +LOSS_MODE=dppo_kl bash examples/dppo_trainer/run_qwen30b_dppo.sh + +# run DPPO-Binary-TV +LOSS_MODE=dppo_tv bash examples/dppo_trainer/run_qwen30b_dppo.sh + +# run GRPO baseline +LOSS_MODE=vanilla CLIP_LOW=0.2 CLIP_HIGH=0.2 bash examples/dppo_trainer/run_qwen30b_dppo.sh +# or GRPO with clip higher +LOSS_MODE=vanilla CLIP_LOW=0.2 CLIP_HIGH=0.28 bash examples/dppo_trainer/run_qwen30b_dppo.sh +``` + +## 📖Introduction + +
+ issue +
+ +Comparison of **PPO** and the proposed **DPPO** (the Binary-TV variant). **(Left)** The surrogate objective and corresponding masks for PPO and DPPO. PPO (and variants like GRPO) employs a heuristic mask based on the probability ratio. In contrast, DPPO utilizes a more principled mask based on a direct approximation of policy divergence (e.g., Total Variation), ensuring updates stay within a theoretically grounded trust region. **(Right)** Experimental results on the AIME24 using Qwen3-30B-A3B-Base. DPPO significantly outperforms GRPO baselines, achieving superior training stability and final performance even without rollout routing replay (R3). + +
+ issue +
+ +DPPO variants achieve stable training while controlling the training-inference mismatch at a low level. In contrast, methods without a trust region (PG-IS, CISPO) or with a misspecified one (MiniRL) suffer from growing mismatch and eventual collapse. + +
+ issue +
+ +The plots show numerical differences between a training and an inference engine for Qwen3-30B-A3B-Base with identical parameters. **(Left)** The probability ratio (used in PPO) is highly volatile for low-probability tokens. **(Right)** In contrast, the TV divergence is more stable. This highlights a key flaw of PPO's clipping mechanism: it **over-penalizes low-probability tokens**, which can slow down learning; and **under-penalizes high-probability tokens**, which can permit large, destabilizing updates. + + +
+ issue +
+ +The most frequently clipped tokens (by GRPO) are important to the reasoning task! +They are dominated by: +- numbers, like 1, 4 +- mathematical symbols, like +, -, = +- reasoning and structural Words: Wait, Thus, Next + +## Top-K divergence approximation + +We only implement the DPPO-Binary-TV/DPPO-Binary-KL here due to their simplicity. + +For the TopK divergence approximation, please refer to the [the original repo](https://github.com/sail-sg/Stable-RL) for a complete implementation. + +## Citation +If you find our works useful for your research, please consider citing: + +```bibtex +@article{qi2026dppo, + title={Rethinking the Trust Region in LLM Reinforcement Learning}, + author={Qi, Penghui and Zhou, Xiangxin and Liu, Zichen and Pang, Tianyu and Du, Chao and Lin, Min and Lee, Wee Sun}, + journal={arXiv preprint arXiv:2602.04879}, + year={2026} +} +``` + +## 🌻Acknowledgement +We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) and [sglang](https://github.com/sgl-project/sglang) for inference. Our models are trained primarily on [Qwen3 family](https://huggingface.co/collections/Qwen/qwen3). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! diff --git a/examples/dppo_trainer/run_qwen30b_dppo.sh b/examples/dppo_trainer/run_qwen30b_dppo.sh new file mode 100644 index 00000000000..7c3783a2ffd --- /dev/null +++ b/examples/dppo_trainer/run_qwen30b_dppo.sh @@ -0,0 +1,262 @@ +# run Qwen3-30B-A3B-Base on dapo-math-17k dataset +set -x + +# ================================ DPPO Specific Parameters =========================== + +# Why from GRPO to DPPO? +""" +The ratio clipping mechanism in GRPO/PPO is structurally ill-suited due to the large, +long-tailed vocabularies inherent to LLMs. It over-penalizes low-probability tokens +and under-penalizes high-probability ones, leading to training inefficiency and +instability. For example, increasing a rare token’s probability from 1e−5 to 1e−3 +generates a massive ratio of 100 that triggers clipping, even though the actual +divergence is negligible. Conversely, small ratio changes on high-probability tokens +can make catastrophic shifts in probability mass (e.g., a drop from 0.99 to 0.8), yet +it often remains unpenalized by the clipping mechanism. + +DPPO addresses this issue by using a divergence-based clipping mechanism, achieving +superior training stability and final performance compared to existing methods. + +DPPO paper: https://arxiv.org/pdf/2602.04879 +""" + +LOSS_MODE=${LOSS_MODE:-"dppo_tv"} + +if [[ $LOSS_MODE == "dppo_kl" ]]; then + # The KL divergence threshold for DPPO. + clip_ratio=0.05 + clip_ratio_low=${CLIP_LOW:-0.05} + clip_ratio_high=${CLIP_HIGH:-0.05} +elif [[ $LOSS_MODE == "dppo_tv" ]]; then + # The TV divergence threshold for DPPO. + clip_ratio=0.15 + clip_ratio_low=${CLIP_LOW:-0.15} + clip_ratio_high=${CLIP_HIGH:-0.15} +elif [[ $LOSS_MODE == "vanilla" ]]; then + # GRPO baseline + clip_ratio=0.2 + clip_ratio_low=${CLIP_LOW:-0.2} + clip_ratio_high=${CLIP_HIGH:-0.28} +else + echo "Invalid loss mode: $LOSS_MODE" + exit 1 +fi + +# Disable dual-clip PPO and TIS for a fair comparison between GRPO and DPPO. +clip_ratio_c=10000.0 + +# ===================================== Algorithm ===================================== +adv_estimator=grpo + +# We recommand directly clipping the ratio/divergence with respect to the original +# rollout policy (implemented by bypass_mode=True), instead of the recomputed one. +# This can not only save the computation cost, but also improve the training stability +# for both GRPO and DPPO by controlling the training-inference mismatch at a low level. +# See Section 5.2 in https://arxiv.org/pdf/2602.04879 for more details. +bypass_mode=True + +# We recommand using Dr.GRPO to remove the length and difficulty bias in original GRPO. +# See Section 3.1 in https://arxiv.org/pdf/2503.20783 for more details. +norm_adv_by_std_in_grpo=False # remove the difficulty bias +loss_agg_mode="seq-mean-token-sum-norm" # remove the length bias + +# reference policy +use_kl_in_reward=False +kl_coef=0.001 +use_kl_loss=False +kl_loss_coef=0.001 + +actor_lr=1e-6 +critic_lr=2e-6 +gae_gamma=1.0 +gae_lam=0.95 +critic_warmup=0 + + +# ================================== Data/Model/Config ================================= + +# Node Info +NNODES=${NNODES:-2} + +# wandb +backend=megatron # fsdp, fsdp2, megatron +project_name=Qwen3-30B-A3B-Base-dapo-math-17k +experiment_name="${backend}-${NNODES}nodes-${LOSS_MODE}-low${clip_ratio_low}-high${clip_ratio_high}" + +# Paths +DATA_ROOT=${DATA_ROOT:-"${HOME}/verl"} +CKPTS_DIR=${CKPTS_DIR:-"${DATA_ROOT}/ckpts/${project_name}/${experiment_name}"} +MODEL_PATH=${MODEL_PATH:-"${DATA_ROOT}/models/Qwen3-30B-A3B-Base"} +TRAIN_FILE=${TRAIN_FILE:-"${DATA_ROOT}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${DATA_ROOT}/data/aime-2024.parquet"} + + +actor_model_path=$MODEL_PATH +critic_model_path=$MODEL_PATH + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +train_batch_size=256 +ppo_mini_batch_size=32 +n_resp_per_prompt=16 +n_resp_per_prompt_val=1 + +# ===================================== Training ====================================== +actor_max_token_len_per_gpu=$(((max_prompt_length + max_response_length) * 1)) +critic_max_token_len_per_gpu=$(((max_prompt_length + max_response_length) * 1)) + +# FSDP parallelism config +USP_SIZE=4 +ACTOR_FSDP_CONFIG=" + actor_rollout_ref.actor.fsdp_config.strategy=$backend \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$USP_SIZE" + +# Megatron parallelism config +TP_SIZE=2 +CP_SIZE=1 +PP_SIZE=1 +VPP_SIZE=null +EP_SIZE=8 +ETP_SIZE=1 +ACTOR_MEGATRON_CONFIG=" + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP_SIZE \ + actor_rollout_ref.actor.megatron.context_parallel_size=$CP_SIZE \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP_SIZE \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$VPP_SIZE \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP_SIZE \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP_SIZE \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + actor_rollout_ref.actor.megatron.use_mbridge=True" + +# Actor model config +ACTOR_CONFIG=" + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.model.path=$actor_model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio=$clip_ratio \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=$clip_ratio_c \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.calculate_entropy=True \ + actor_rollout_ref.actor.policy_loss.loss_mode=${LOSS_MODE} \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu" + +# Critic model config +CIRITC_CONFIG=" + critic.optim.lr=$critic_lr \ + critic.model.path=$critic_model_path \ + critic.model.use_remove_padding=True \ + critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \ + critic.ulysses_sequence_parallel_size=$USP_SIZE" + +CRITIC_FSDP_CONFIG="${ACTOR_FSDP_CONFIG//actor_rollout_ref.actor/critic.model}" +CRITIC_MEGATRON_CONFIG="${ACTOR_MEGATRON_CONFIG//actor_rollout_ref.actor/critic}" + +if [[ $backend == "megatron" ]]; then + CONFIG_NAME=ppo_megatron_trainer + ACTOR_CONFIG="$ACTOR_CONFIG $ACTOR_MEGATRON_CONFIG" + if [[ $adv_estimator == "gae" ]]; then + CIRITC_CONFIG="$CIRITC_CONFIG $CRITIC_MEGATRON_CONFIG" + else + CIRITC_CONFIG="" + fi +else # fsdp, fsdp2 + CONFIG_NAME=ppo_trainer + ACTOR_CONFIG="$ACTOR_CONFIG $ACTOR_FSDP_CONFIG" + if [[ $adv_estimator == "gae" ]]; then + CIRITC_CONFIG="$CIRITC_CONFIG $CRITIC_FSDP_CONFIG" + else + CIRITC_CONFIG="" + fi +fi + +# ===================================== Inference ===================================== +rollout_name=vllm +if [ "$rollout_name" = "vllm" ]; then + export VLLM_USE_V1=1 +fi +infer_tp=4 +infer_dp=1 +infer_ep=1 +gpu_memory_utilization=0.7 + +ROLLOUT_CONFIG=" + actor_rollout_ref.rollout.name=$rollout_name \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.data_parallel_size=$infer_dp \ + actor_rollout_ref.rollout.expert_parallel_size=$infer_ep \ + actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val" + +# ===================================== Reward ===================================== +REWARD_CONFIG=" + reward.reward_manager.name=dapo \ + +reward.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward.reward_kwargs.max_resp_len=${max_response_length}" + +python3 -m verl.trainer.main_ppo \ + --config-path=./config \ + --config-name=$CONFIG_NAME \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + algorithm.gamma=$gae_gamma \ + algorithm.lam=$gae_lam \ + algorithm.rollout_correction.bypass_mode=$bypass_mode \ + algorithm.norm_adv_by_std_in_grpo=$norm_adv_by_std_in_grpo \ + data.train_files="$TRAIN_FILE" \ + data.val_files="$TEST_FILE" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=False \ + data.filter_overlong_prompts_workers=64 \ + data.truncation='error' \ + trainer.use_legacy_worker_impl=disable \ + trainer.critic_warmup=$critic_warmup \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.default_local_dir=$CKPTS_DIR \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$NNODES \ + trainer.val_before_train=False \ + trainer.log_val_generations=100 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=500 \ + $ACTOR_CONFIG \ + $CIRITC_CONFIG \ + $ROLLOUT_CONFIG \ + $REWARD_CONFIG diff --git a/examples/grpo_trainer/run_qwen3_vl-8b_npu.sh b/examples/grpo_trainer/run_qwen3_vl-8b_npu.sh new file mode 100644 index 00000000000..1a6b6b55d4f --- /dev/null +++ b/examples/grpo_trainer/run_qwen3_vl-8b_npu.sh @@ -0,0 +1,87 @@ +set -x + +project_name='GRPO-Qwen3_vl' +exp_name='GRPO-Qwen3_vl-8B-npu' +gen_tp=1 +sp_size=1 +ENGINE=${1:-vllm} +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-8B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/geo3k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/geo3k/test.parquet"} + +# Rollout Correction parameters (sequence-level TIS + geometric RS) +rollout_is=sequence +rollout_is_threshold=2.0 +rollout_is_batch_normalize=true +rollout_rs=token_k1 +rollout_rs_threshold=0.6_1.6 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.rollout.max_num_batched_tokens=20000 \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=32 \ + actor_rollout_ref.actor.fsdp_config.reshard_after_forward=True \ + actor_rollout_ref.ref.fsdp_config.reshard_after_forward=True \ + actor_rollout_ref.actor.fsdp_config.entropy_checkpointing=True \ + actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$sp_size \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=$sp_size \ + actor_rollout_ref.rollout.n=5 \ + data.shuffle=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=4 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.resume_mode=auto \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_is_batch_normalize=${rollout_is_batch_normalize} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + trainer.val_before_train=True \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen3_vl_30b_vllm_fsdp_npu.sh b/examples/grpo_trainer/run_qwen3_vl_30b_vllm_fsdp_npu.sh new file mode 100644 index 00000000000..d42cf1944d8 --- /dev/null +++ b/examples/grpo_trainer/run_qwen3_vl_30b_vllm_fsdp_npu.sh @@ -0,0 +1,87 @@ +set -x + +project_name='GRPO-Qwen3_vl' +exp_name='GRPO-Qwen3_vl-30B-npu' +gen_tp=8 +sp_size=2 +ENGINE=${1:-vllm} +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/geo3k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/geo3k/test.parquet"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Rollout Correction parameters +rollout_is=sequence +rollout_is_threshold=2.0 +rollout_is_batch_normalize=true +rollout_rs=token_k1 +rollout_rs_threshold=0.6_1.6 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=32 \ + actor_rollout_ref.actor.fsdp_config.reshard_after_forward=True \ + actor_rollout_ref.ref.fsdp_config.reshard_after_forward=True \ + actor_rollout_ref.actor.fsdp_config.entropy_checkpointing=True \ + actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$sp_size \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=$sp_size \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.max_num_batched_tokens=20000 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + algorithm.use_kl_in_reward=False \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_is_batch_normalize=${rollout_is_batch_normalize} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + trainer.critic_warmup=0 \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=2 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.resume_mode=auto \ + trainer.val_before_train=True \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 \ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora_fp16.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora_fp16.sh new file mode 100644 index 00000000000..7b06f48de9c --- /dev/null +++ b/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora_fp16.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +set -xeuo pipefail +pwd=`pwd` + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +TP=${TP:-2} +PP=${PP:-2} +CP=${CP:-2} +EP=${EP:-4} +ETP=${ETP:-1} + +ALL_OFFLOAD=${ALL_OFFLOAD:-True} + +optimizer_offload_fraction=1. + +dtype="float16" # ["bfloat16", "float16"] +rollout_name="vllm" +project_name='verl_grpo_example_gsm8k_math_fp16' +exp_name='qwen3_30b_a3b_megatron_lora' +adv_estimator=grpo + +# Paths +MODEL_PATH=$HOME/Qwen/Qwen3-30B-A3B-Instruct-2507 +CKPTS_DIR=${pwd}/ckpt/${exp_name} + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet + +########################### Parameter Arrays ########################### + +DATA=( + data.train_files=${gsm8k_train_path} + data.val_files=${gsm8k_test_path} + data.train_batch_size=128 + data.max_prompt_length=1024 + data.max_response_length=1024 + data.truncation='error' + data.filter_overlong_prompts=True + data.shuffle=False + data.return_raw_chat=$return_raw_chat + data.filter_overlong_prompts_workers=128 +) + +MODEL=( + actor_rollout_ref.model.path=${MODEL_PATH} + actor_rollout_ref.model.lora.rank=16 + actor_rollout_ref.model.lora.alpha=32 + actor_rollout_ref.model.lora.dtype=${dtype} + actor_rollout_ref.model.use_fused_kernels=True +) + +ACTOR=( + actor_rollout_ref.actor.optim.lr=3e-6 + actor_rollout_ref.actor.ppo_mini_batch_size=16 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.megatron.use_mbridge=True + actor_rollout_ref.actor.megatron.vanilla_mbridge=False + actor_rollout_ref.actor.use_dynamic_bsz=True + actor_rollout_ref.actor.use_kl_loss=True + actor_rollout_ref.actor.kl_loss_coef=0.001 + actor_rollout_ref.actor.kl_loss_type=low_var_kl + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.actor.megatron.context_parallel_size=${CP} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.dtype=${dtype} + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True + +actor_rollout_ref.actor.megatron.override_ddp_config.grad_reduce_in_fp32=True + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=${ALL_OFFLOAD} +) + +ROLLOUT=( + actor_rollout_ref.rollout.tensor_model_parallel_size=8 + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 + actor_rollout_ref.rollout.enforce_eager=True + actor_rollout_ref.rollout.free_cache_engine=True + actor_rollout_ref.rollout.n=4 + actor_rollout_ref.rollout.dtype=${dtype} +) + +REF=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True + actor_rollout_ref.ref.megatron.dtype=${dtype} + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP} + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP} + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.ref.megatron.context_parallel_size=${CP} + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD} +) + +ALGORITHM=( + algorithm.adv_estimator=${adv_estimator} +) + +TRAINER=( + trainer.critic_warmup=0 + trainer.logger='["console","wandb"]' + trainer.project_name=${project_name} + trainer.experiment_name=${exp_name} + trainer.n_gpus_per_node=8 + trainer.nnodes=1 + trainer.save_freq=20 + trainer.test_freq=5 + trainer.total_epochs=15 + trainer.val_before_train=False + trainer.max_actor_ckpt_to_keep=1 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.log_val_generations=10 +) + +########################### Launch ########################### + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REF[@]}" \ + "${TRAINER[@]}" \ + 2>&1 | tee ${pwd}/log/${exp_name}_$(date +'%Y%m%d_%H%M%S').log \ No newline at end of file diff --git a/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh b/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh index 69fbf4251ee..2aa1c903380 100644 --- a/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh +++ b/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh @@ -40,8 +40,8 @@ WORKING_DIR=${WORKING_DIR:-"${PWD}"} RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} # Data Length Configuration -max_prompt_length=$((1024 * 16)) -max_response_length=$((1024 * 16)) +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) # Training Batch Configuration train_prompt_bsz=256 @@ -60,20 +60,21 @@ clip_ratio_low=0.0003 clip_ratio_high=0.0004 loss_agg_mode="seq-mean-token-mean" +# FSDP Parallelism Configuration +actor_strategy=fsdp2 +ref_strategy=fsdp2 +sp_size=4 +fsdp_size=-1 + # Performance and Memory Management Configuration offload=True use_dynamic_bsz=True actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) -# FSDP Parallelism Configuration -actor_strategy=fsdp2 -ref_strategy=fsdp2 -sp_size=4 -fsdp_size=-1 # vLLM Configuration gen_tp=4 -gpu_memory_utilization=0.9 +gpu_memory_utilization=0.7 max_model_len=$((max_prompt_length + max_response_length)) max_num_batched_tokens=$((max_prompt_length + max_response_length)) @@ -161,7 +162,7 @@ ROLLOUT_CONFIG=( actor_rollout_ref.rollout.enable_chunked_prefill=True actor_rollout_ref.rollout.enforce_eager=False actor_rollout_ref.rollout.free_cache_engine=True - +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes="[8, 16, 32, 64, 128, 192, 256, 384]" + +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes="[8, 16, 32, 64, 128, 192, 256]" +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_DECODE_ONLY" actor_rollout_ref.rollout.val_kwargs.n=1 actor_rollout_ref.rollout.val_kwargs.do_sample=True diff --git a/examples/sapo_trainer/run_qwen3_8b_sapo_npu.sh b/examples/sapo_trainer/run_qwen3_8b_sapo_npu.sh new file mode 100644 index 00000000000..a9320b8a54d --- /dev/null +++ b/examples/sapo_trainer/run_qwen3_8b_sapo_npu.sh @@ -0,0 +1,95 @@ +set -euxo pipefail + +ulimit -n 32768 + +## Basic Environment Settings +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export TASK_QUEUE_ENABLE=1 +export HCCL_EXEC_TIMEOUT=3600 +export HCCL_CONNECT_TIMEOUT=3600 +export HCCL_ASYNC_ERROR_HANDLING=0 +export CPU_AFFINITY_CONF=1 +export VLLM_USE_V1=1 + +project_name='SAPO-Qwen3' +exp_name='SAPO-Qwen3-8B-npu' +gen_tp=2 +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/Qwen3-8B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/dataset/dapo_processed/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/dataset/aime-24_processed/train.parquet"} + +# reference policy +use_kl_in_reward=False +kl_coef=0.001 +use_kl_loss=False +kl_loss_coef=0.001 + +# ------Algorithm settings------- +# Positive and negative tau for smoothing function in SAPO (https://arxiv.org/pdf/2511.20347) +# default values used in the paper with Qwen3-30B-A3B-Base +# clipping is not used in SAPO! + +loss_mode=sapo # explicitly specify sapo! default is vanilla and is not compatible with SAPO. It uses clipping instead of smoothing. + +tau_pos=1.0 +tau_neg=1.05 + +gae_gamma=1.0 +gae_lam=0.95 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + algorithm.gamma=$gae_gamma \ + algorithm.lam=$gae_lam \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.filter_overlong_prompts_workers=64 \ + data.truncation='error' \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.tau_pos=$tau_pos \ + actor_rollout_ref.actor.tau_neg=$tau_neg \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + trainer.critic_warmup=0 \ + trainer.logger='["console"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.resume_mode=auto \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 diff --git a/examples/sft/gsm8k/run_deepseek_6b7.sh b/examples/sft/gsm8k/run_deepseek_6b7.sh index 8a067f05d50..f4a654c12c0 100644 --- a/examples/sft/gsm8k/run_deepseek_6b7.sh +++ b/examples/sft/gsm8k/run_deepseek_6b7.sh @@ -12,15 +12,14 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ + data.messages_key=messages \ data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ + optim.lr=1e-4 \ + engine=fsdp \ + model.path=deepseek-ai/deepseek-coder-6.7b-instruct \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh index 5b59893d258..34e656f7540 100644 --- a/examples/sft/gsm8k/run_gemma_2b.sh +++ b/examples/sft/gsm8k/run_gemma_2b.sh @@ -14,15 +14,14 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ + data.messages_key=messages \ data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=google/gemma-2b-it \ + model.path=google/gemma-2b-it \ + optim.lr=1e-4 \ + engine=fsdp \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-2b-it \ diff --git a/examples/sft/gsm8k/run_gemma_7b.sh b/examples/sft/gsm8k/run_gemma_7b.sh index fe2bc3a6f39..868ae07e461 100644 --- a/examples/sft/gsm8k/run_gemma_7b.sh +++ b/examples/sft/gsm8k/run_gemma_7b.sh @@ -12,15 +12,14 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - data.response_dict_keys=['answer'] \ + data.messages_key=messages \ data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=google/gemma-1.1-7b-it \ + optim.lr=1e-4 \ + engine=fsdp \ + model.path=google/gemma-1.1-7b-it \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \ diff --git a/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh b/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh index 7de7ebd67e4..c8e3fa7fdc3 100644 --- a/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh +++ b/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh @@ -12,16 +12,15 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ data.micro_batch_size_per_gpu=64 \ - model.partial_pretrain=Qwen/Qwen3-8B \ + optim.lr=1e-4 \ + engine=fsdp \ + engine.ulysses_sequence_parallel_size=2 \ + model.path=Qwen/Qwen3-8B \ + model.use_remove_padding=true \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen3-8b-instruct \ @@ -29,7 +28,4 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.total_epochs=2 $@ \ model.lora_rank=32 \ model.lora_alpha=16 \ - model.target_modules=all-linear \ - model.strategy=fsdp \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true + model.target_modules=all-linear diff --git a/examples/sft/gsm8k/run_qwen_05_peft.sh b/examples/sft/gsm8k/run_qwen_05_peft.sh index 3a7d4455807..7ac5fc303f5 100644 --- a/examples/sft/gsm8k/run_qwen_05_peft.sh +++ b/examples/sft/gsm8k/run_qwen_05_peft.sh @@ -14,16 +14,13 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + optim.lr=1e-4 \ + engine=fsdp \ + model.path=Qwen/Qwen2.5-0.5B-Instruct \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh index 7210a5a4038..33320d0d02f 100644 --- a/examples/sft/gsm8k/run_qwen_05_sp2.sh +++ b/examples/sft/gsm8k/run_qwen_05_sp2.sh @@ -12,20 +12,18 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ + data.messages_key=messages \ data.micro_batch_size=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + optim.lr=1e-4 \ + engine=fsdp \ + engine.ulysses_sequence_parallel_size=2 \ + model.path=Qwen/Qwen2.5-0.5B-Instruct \ + model.use_remove_padding=true \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \ trainer.logger=console \ - trainer.total_training_steps=1 $@ \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true + trainer.total_training_steps=1 $@ diff --git a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh index 1c5cd591f14..4335aea5840 100644 --- a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh +++ b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh @@ -12,20 +12,18 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ + data.messages_key=messages \ data.micro_batch_size=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + optim.lr=1e-4 \ + engine=fsdp \ + engine.ulysses_sequence_parallel_size=2 \ + model.path=Qwen/Qwen2.5-0.5B-Instruct \ model.use_liger=True \ + model.use_remove_padding=true \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2-liger \ - trainer.logger=console $@ \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true + trainer.logger=console $@ diff --git a/examples/sft/gsm8k/run_seed_oss_36b_sft.sh b/examples/sft/gsm8k/run_seed_oss_36b_sft.sh index 35c1d6c6d34..fe6f2822760 100644 --- a/examples/sft/gsm8k/run_seed_oss_36b_sft.sh +++ b/examples/sft/gsm8k/run_seed_oss_36b_sft.sh @@ -12,20 +12,17 @@ save_path=$2 shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ data.micro_batch_size=4 \ - model.partial_pretrain=ByteDance-Seed/Seed-OSS-36B-Base \ + optim.lr=1e-4 \ + engine=fsdp \ + engine.ulysses_sequence_parallel_size=2 \ + model.path=ByteDance-Seed/Seed-OSS-36B-Base \ + model.use_remove_padding=true \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-seed-oss-36b \ trainer.logger=console \ - trainer.total_training_steps=1 \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true $@ + trainer.total_training_steps=1 $@ diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh index 5e1fc47e9c5..511eb600ba0 100644 --- a/examples/sft/multiturn/run_qwen_05_sp2.sh +++ b/examples/sft/multiturn/run_qwen_05_sp2.sh @@ -13,7 +13,7 @@ save_path=$2 shift 2 torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ + -m verl.trainer.sft_trainer \ data.train_files=$HOME/data/multiturn/train.parquet \ data.val_files=$HOME/data/multiturn/test.parquet \ data.multiturn.enable=true \ diff --git a/pyproject.toml b/pyproject.toml index 89bf6798a8b..a597421e920 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,4 +110,5 @@ verl = [ "version/*", "trainer/config/*.yaml", "trainer/config/*/*.yaml", + "experimental/*/config/*.yaml", ] diff --git a/recipe b/recipe index 21892b92769..3490a22a0a3 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit 21892b9276936efab5375c3f6b8415e472ef7118 +Subproject commit 3490a22a0a3adeb7e4787fe70b1060b642efbae4 diff --git a/requirements-npu.txt b/requirements-npu.txt index ea197c98f31..fada1f839c2 100644 --- a/requirements-npu.txt +++ b/requirements-npu.txt @@ -18,4 +18,4 @@ torchdata einops qwen_vl_utils hf_transfer -triton-ascend==3.2.0rc4 \ No newline at end of file +triton-ascend==3.2.0 diff --git a/scripts/generate_trainer_config.sh b/scripts/generate_trainer_config.sh index 06bb371d06c..c4c89cdbdba 100755 --- a/scripts/generate_trainer_config.sh +++ b/scripts/generate_trainer_config.sh @@ -7,30 +7,31 @@ CONFIG_SPECS=( "ppo_trainer:_generated_ppo_trainer.yaml:" "ppo_megatron_trainer:_generated_ppo_megatron_trainer.yaml:--config-name=ppo_megatron_trainer.yaml" "ppo_trainer:_generated_ppo_veomni_trainer.yaml:model_engine=veomni" + "ppo_trainer:_generated_ppo_torchtitan_trainer.yaml:model_engine=torchtitan" ) generate_config() { local config_name="$1" local output_file="$2" local config_arg="$3" - + local target_cfg="verl/trainer/config/${output_file}" local tmp_header=$(mktemp) local tmp_cfg=$(mktemp) - + echo "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'" > "$tmp_header" echo "# in which it invokes 'python3 scripts/print_cfg.py --cfg job ${config_arg}' to flatten the 'verl/trainer/config/${config_name}.yaml' config fields into a single file." >> "$tmp_header" echo "# Do not modify this file directly." >> "$tmp_header" echo "# The file is usually only for reference and never used." >> "$tmp_header" echo "" >> "$tmp_header" - + python3 scripts/print_cfg.py --cfg job ${config_arg} > "$tmp_cfg" - + cat "$tmp_header" > "$target_cfg" sed -n '/^actor_rollout_ref/,$p' "$tmp_cfg" >> "$target_cfg" - + rm "$tmp_cfg" "$tmp_header" - + echo "Generated: $target_cfg" } diff --git a/scripts/init_random_model.py b/scripts/init_random_model.py index 2bc3ffc1b80..432fffe2b30 100644 --- a/scripts/init_random_model.py +++ b/scripts/init_random_model.py @@ -87,10 +87,10 @@ def init_random_model(hf_model_path, new_config_path, output_path, trust_remote_ print(f"new_config: {new_confg}") if trust_remote_code: model = AutoModelForCausalLM.from_pretrained( - hf_model_path, config=new_confg, trust_remote_code=trust_remote_code + hf_model_path, config=new_confg, trust_remote_code=trust_remote_code, torch_dtype=new_confg.torch_dtype ) else: - model = AutoModelForCausalLM.from_config(new_confg) + model = AutoModelForCausalLM.from_config(new_confg, torch_dtype=new_confg.torch_dtype) model.save_pretrained(output_path) tokenizer.save_pretrained(output_path) new_confg.save_pretrained(output_path) diff --git a/scripts/install_sglang_mcore_npu.sh b/scripts/install_sglang_mcore_npu.sh index 2975db3d1ed..86678faed56 100644 --- a/scripts/install_sglang_mcore_npu.sh +++ b/scripts/install_sglang_mcore_npu.sh @@ -1,6 +1,7 @@ #!/bin/bash set -e NPU_DEVICE=${NPU_DEVICE:=A3} +USE_MEGATRON=${USE_MEGATRON:-1} export MAX_JOBS=32 diff --git a/setup.py b/setup.py index 9cde2eb2391..af51223f5d8 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,6 @@ ] TRL_REQUIRES = ["trl<=0.9.6"] MCORE_REQUIRES = ["mbridge"] -TRANSFERQUEUE_REQUIRES = ["TransferQueue==0.1.5"] extras_require = { "test": TEST_REQUIRES, @@ -70,7 +69,6 @@ "sglang": SGLANG_REQUIRES, "trl": TRL_REQUIRES, "mcore": MCORE_REQUIRES, - "transferqueue": TRANSFERQUEUE_REQUIRES, "trtllm": TRTLLM_REQUIRES, } @@ -92,7 +90,11 @@ extras_require=extras_require, package_data={ "": ["version/*"], - "verl": ["trainer/config/*.yaml"], + "verl": [ + "trainer/config/*.yaml", + "trainer/config/*/*.yaml", + "experimental/*/config/*.yaml", + ], }, include_package_data=True, long_description=long_description, diff --git a/tests/checkpoint_engine/test_correctness_on_gpu.py b/tests/checkpoint_engine/test_correctness_on_gpu.py index ff4a959b20f..05cf27cf4a2 100644 --- a/tests/checkpoint_engine/test_correctness_on_gpu.py +++ b/tests/checkpoint_engine/test_correctness_on_gpu.py @@ -22,12 +22,15 @@ RayResourcePool, split_resource_pool, ) +from verl.utils.device import get_device_name +from verl.utils.ray_utils import auto_await from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig @pytest.mark.asyncio @pytest.mark.parametrize("rebuild_group", [False, True]) @pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +@auto_await async def test_nccl_checkpoint_engine( rebuild_group, num_trainer, @@ -64,7 +67,7 @@ async def test_nccl_checkpoint_engine( rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) # create checkpoint engine manager - checkpoint_manager = CheckpointEngineManager(backend="nccl", trainer=trainer, replicas=replicas) + checkpoint_manager = CheckpointEngineManager(config=checkpoint_engine_config, trainer=trainer, replicas=replicas) for _ in range(3): await checkpoint_manager.update_weights() rollout.check_weights() @@ -76,6 +79,7 @@ async def test_nccl_checkpoint_engine( @pytest.mark.asyncio @pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +@auto_await async def test_nixl_checkpoint_engine( num_trainer, num_rollout, @@ -119,7 +123,55 @@ async def test_nixl_checkpoint_engine( rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) # create checkpoint engine manager - checkpoint_manager = CheckpointEngineManager(backend="nixl", trainer=trainer, replicas=replicas) + checkpoint_manager = CheckpointEngineManager(config=checkpoint_engine_config, trainer=trainer, replicas=replicas) + for _ in range(3): + await checkpoint_manager.update_weights() + rollout.check_weights() + + ray.shutdown() + + +@pytest.mark.skip(reason="temporary skip since our ci environment is not ready") +@pytest.mark.asyncio +@pytest.mark.parametrize("rebuild_group", [False]) +@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +@auto_await +async def test_kimi_checkpoint_engine( + rebuild_group, + num_trainer, + num_rollout, + num_nodes=1, + num_gpus_per_node=8, + check_allclose=True, + model_path="~/models/Qwen/Qwen3-8B-Base", +): + model_path = os.path.expanduser(model_path) + ray.init( + runtime_env={ + "env_vars": { + "NCCL_IB_HCA": "mlx5", + "VERL_LOGGING_LEVEL": "DEBUG", + } + } + ) + + # initialize config + checkpoint_engine_config = CheckpointEngineConfig( + backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}} + ) + model_config = HFModelConfig(path=model_path, use_remove_padding=True) + rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) + + # create trainer and rollout worker group + resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) + resource_pool.get_placement_groups(device_name=get_device_name()) + trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) + trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config) + trainer.reset() + rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) + + # create checkpoint engine manager + checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas) for _ in range(3): await checkpoint_manager.update_weights() rollout.check_weights() diff --git a/tests/checkpoint_engine/test_correctness_on_npu.py b/tests/checkpoint_engine/test_correctness_on_npu.py index b99fcc771be..17f7dbe4b8e 100644 --- a/tests/checkpoint_engine/test_correctness_on_npu.py +++ b/tests/checkpoint_engine/test_correctness_on_npu.py @@ -23,12 +23,14 @@ split_resource_pool, ) from verl.utils.device import get_device_name +from verl.utils.ray_utils import auto_await from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig @pytest.mark.asyncio @pytest.mark.parametrize("rebuild_group", [False]) @pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +@auto_await async def test_hccl_checkpoint_engine( rebuild_group, num_trainer, @@ -66,7 +68,54 @@ async def test_hccl_checkpoint_engine( rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) # create checkpoint engine manager - checkpoint_manager = CheckpointEngineManager(backend="hccl", trainer=trainer, replicas=replicas) + checkpoint_manager = CheckpointEngineManager(config=checkpoint_engine_config, trainer=trainer, replicas=replicas) + for _ in range(3): + await checkpoint_manager.update_weights() + rollout.check_weights() + + ray.shutdown() + + +@pytest.mark.skip(reason="temporary skip since our ci environment is not ready") +@pytest.mark.asyncio +@pytest.mark.parametrize("rebuild_group", [False]) +@pytest.mark.parametrize("num_trainer, num_rollout", [(4, 28)]) +async def test_kimi_checkpoint_engine( + rebuild_group, + num_trainer, + num_rollout, + num_nodes=2, + num_gpus_per_node=16, + check_allclose=True, + model_path="~/models/Qwen/Qwen3-32B", +): + model_path = os.path.expanduser(model_path) + ray.init( + runtime_env={ + "env_vars": { + "HCCL_CONNECT_TIMEOUT": "1500", + "VERL_LOGGING_LEVEL": "DEBUG", + } + } + ) + + # initialize config + checkpoint_engine_config = CheckpointEngineConfig( + backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}} + ) + model_config = HFModelConfig(path=model_path, use_remove_padding=True) + rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) + + # create trainer and rollout worker group + resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) + resource_pool.get_placement_groups(device_name=get_device_name()) + trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) + trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config) + trainer.reset() + rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) + + # create checkpoint engine manager + checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas) for _ in range(3): await checkpoint_manager.update_weights() rollout.check_weights() diff --git a/tests/checkpoint_engine/test_special_server_adapter.py b/tests/checkpoint_engine/test_special_server_adapter.py index 193a9eaeb56..3eb8a224f87 100644 --- a/tests/checkpoint_engine/test_special_server_adapter.py +++ b/tests/checkpoint_engine/test_special_server_adapter.py @@ -101,7 +101,7 @@ async def test_server_adapter(init_config): # 3. create checkpoint engine manager checkpoint_manager = CheckpointEngineManager( - backend=checkpoint_engine_config.backend, trainer=trainer, replicas=rollout_replicas + config=checkpoint_engine_config, trainer=trainer, replicas=rollout_replicas ) for i in range(3): await checkpoint_manager.update_weights() diff --git a/tests/checkpoint_engine/test_utils.py b/tests/checkpoint_engine/test_utils.py index 02e3c8f1031..b64ba9f776d 100644 --- a/tests/checkpoint_engine/test_utils.py +++ b/tests/checkpoint_engine/test_utils.py @@ -31,12 +31,13 @@ class TrainingWorkerTest(TrainingWorker): def __init__(self, config: TrainingWorkerConfig, checkpoint_engine_config: CheckpointEngineConfig) -> None: super().__init__(config) + backend = checkpoint_engine_config.backend bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20 engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {}) - self.checkpoint_engine = CheckpointEngineRegistry.new( - backend, is_master=(torch.distributed.get_rank() == 0), bucket_size=bucket_size, **engine_kwargs - ) + if torch.distributed.get_rank() == 0: + engine_kwargs["is_master"] = True + self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) async def update_weights(self): @@ -107,9 +108,11 @@ async def launch_servers(self): class CheckpointEngineWorkerTest(CheckpointEngineWorker): - def __init__(self, rollout_config: RolloutConfig, model_config: HFModelConfig, check_allclose: bool = True) -> None: + def __init__( + self, rollout_config: RolloutConfig, model_config: HFModelConfig, check_allclose: bool = True, *args, **kwargs + ) -> None: server_adapter = MockServerAdapter(rollout_config, model_config, check_allclose) - super().__init__(rollout_config, model_config, server_adapter) + super().__init__(rollout_config, model_config, server_adapter, *args, **kwargs) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def check_weights(self): diff --git a/tests/experimental/agent_loop/agent_utils.py b/tests/experimental/agent_loop/agent_utils.py index ad2c297f142..4596236bc78 100644 --- a/tests/experimental/agent_loop/agent_utils.py +++ b/tests/experimental/agent_loop/agent_utils.py @@ -21,14 +21,13 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role -from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker +from verl.utils import omega_conf_to_dataclass +from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup: # =========================== 1. Create hybrid ActorRollout workers =========================== - actor_rollout_cls = ( - AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker - ) + actor_rollout_cls = AsyncActorRolloutRefWorker role_worker_mapping = { Role.ActorRollout: ray.remote(actor_rollout_cls), } @@ -80,13 +79,13 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG config=config, rm_resource_pool=rm_resource_pool, ) - agent_loop_manager = AgentLoopManager( + agent_loop_manager = AgentLoopManager.create( config=config, worker_group=actor_rollout_wg, reward_loop_worker_handles=reward_loop_manager.reward_loop_workers, ) checkpoint_manager = CheckpointEngineManager( - backend=config.actor_rollout_ref.rollout.checkpoint_engine.backend, + config=omega_conf_to_dataclass(config.actor_rollout_ref.rollout.checkpoint_engine), trainer=actor_rollout_wg, replicas=agent_loop_manager.rollout_replicas, ) diff --git a/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py b/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py new file mode 100644 index 00000000000..e5d296a8756 --- /dev/null +++ b/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py @@ -0,0 +1,260 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +import pytest +import torch +from omegaconf import OmegaConf + +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopMetrics, + AgentLoopWorker, + DictConfigWrap, + _InternalAgentLoopOutput, +) +from verl.experimental.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop +from verl.experimental.fully_async_policy.agent_loop.partial_single_turn_agent_loop import PartialSingleTurnAgentLoop +from verl.protocol import DataProto +from verl.utils.dataset.rl_dataset import RLHFDataset + + +@dataclass +class _FakeTokenOutput: + token_ids: list[int] + log_probs: Optional[list[float]] = None + routed_experts: Any = None + num_preempted: Optional[int] = None + + +class _FakeServerManager: + async def generate( + self, + request_id: str, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: Optional[list[Any]] = None, + video_data: Optional[list[Any]] = None, + ) -> _FakeTokenOutput: + del request_id, sampling_params, image_data, video_data + # Return a short, deterministic "generation" for testing. + return _FakeTokenOutput(token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0]) + + async def generate_for_partial( + self, + request_id: str, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: Optional[list[Any]] = None, + video_data: Optional[list[Any]] = None, + ) -> tuple[list[int], list[float], bool]: + del request_id, sampling_params, image_data, video_data + # Return a short partial generation and "not cancelled". + response_ids = prompt_ids[-1:] + [21, 22] + response_logprobs = [0.0] * len(response_ids) + return response_ids, response_logprobs, False + + +class _FakeTokenizer: + def apply_chat_template( + self, + messages: list[dict[str, Any]], + *, + tools: Optional[list[dict]] = None, + add_generation_prompt: bool = True, + tokenize: bool = True, + **kwargs, + ) -> list[int]: + del messages, tools, add_generation_prompt, tokenize, kwargs + # Minimal tokenization: return a small prompt. + return [101, 102] + + def decode(self, ids: list[int] | torch.Tensor, skip_special_tokens: bool = True) -> str: + del ids, skip_special_tokens + return "" + + +def _pad_1d(ids: list[int], *, length: int, pad_id: int = 0) -> list[int]: + if len(ids) > length: + return ids[:length] + return ids + [pad_id] * (length - len(ids)) + + +def _to_internal( + *, + output_prompt_ids: list[int], + output_response_ids: list[int], + output_response_mask: list[int], + metrics: AgentLoopMetrics, + extra_fields: dict[str, Any], + num_turns: int, + prompt_len: int, + response_len: int, +) -> _InternalAgentLoopOutput: + prompt_ids = _pad_1d(output_prompt_ids, length=prompt_len, pad_id=0) + response_ids = _pad_1d(output_response_ids, length=response_len, pad_id=0) + response_mask = _pad_1d(output_response_mask, length=response_len, pad_id=0) + + seq_len = prompt_len + response_len + attention_mask = _pad_1d([1] * len(output_prompt_ids), length=prompt_len, pad_id=0) + _pad_1d( + [1] * len(output_response_ids), + length=response_len, + pad_id=0, + ) + input_ids = prompt_ids + response_ids + position_ids = list(range(seq_len)) + + def t(x: list[int]) -> torch.Tensor: + return torch.tensor([x], dtype=torch.long) + + return _InternalAgentLoopOutput( + prompt_ids=t(prompt_ids), + response_ids=t(response_ids), + response_mask=t(response_mask), + attention_mask=t(attention_mask), + input_ids=t(input_ids), + position_ids=t(position_ids), + response_logprobs=None, + routed_experts=None, + multi_modal_inputs=None, + multi_modal_data=None, + reward_score=None, + num_turns=num_turns, + metrics=metrics, + extra_fields=extra_fields, + ) + + +@pytest.mark.asyncio +async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu(): + # Minimal config surface used by the agent loops. + config = OmegaConf.create( + { + "actor_rollout_ref": { + "rollout": {"prompt_length": 16, "response_length": 16, "multi_turn": {"tool_config_path": None}}, + "model": {}, + }, + "data": { + "tool_config_path": None, + "apply_chat_template_kwargs": {}, + }, + } + ) + + server_manager = _FakeServerManager() + tokenizer = _FakeTokenizer() + processor = None + + trainer_config = DictConfigWrap(config) + data_config = DictConfigWrap(config.data) + + single_turn = SingleTurnAgentLoop( + trainer_config=trainer_config, + server_manager=server_manager, + tokenizer=tokenizer, + processor=processor, + dataset_cls=RLHFDataset, + data_config=data_config, + ) + partial_single_turn = PartialSingleTurnAgentLoop( + trainer_config=trainer_config, + server_manager=server_manager, + tokenizer=tokenizer, + processor=processor, + dataset_cls=RLHFDataset, + data_config=data_config, + ) + + raw_prompt = [{"role": "user", "content": "hi"}] + sampling_params: dict[str, Any] = {} + + out_a = await single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt) + out_b = await partial_single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt, param_version=0) + + # Agent loop outputs should always contain these fields with consistent types. + assert out_a.extra_fields["turn_scores"] == [] + assert out_a.extra_fields["tool_rewards"] == [] + assert out_b.extra_fields["turn_scores"] == [] + assert out_b.extra_fields["tool_rewards"] == [] + + prompt_len = max(len(out_a.prompt_ids), len(out_b.prompt_ids)) + response_len = max(len(out_a.response_ids), len(out_b.response_ids)) + + internal_a = _to_internal( + output_prompt_ids=out_a.prompt_ids, + output_response_ids=out_a.response_ids, + output_response_mask=out_a.response_mask, + metrics=out_a.metrics, + extra_fields=out_a.extra_fields, + num_turns=out_a.num_turns, + prompt_len=prompt_len, + response_len=response_len, + ) + internal_b = _to_internal( + output_prompt_ids=out_b.prompt_ids, + output_response_ids=out_b.response_ids, + output_response_mask=out_b.response_mask, + metrics=out_b.metrics, + extra_fields=out_b.extra_fields, + num_turns=out_b.num_turns, + prompt_len=prompt_len, + response_len=response_len, + ) + + # Mimic two "worker chunks" and concatenate as in training. + dummy_worker = type("_DummyWorker", (), {"reward_loop_worker_handles": None})() + chunk_a = AgentLoopWorker._postprocess( + dummy_worker, + inputs=[internal_a], + input_non_tensor_batch={ + "index": np.array([0], dtype=object), + "agent_name": np.array(["single_turn_agent"], dtype=object), + }, + ) + chunk_b = AgentLoopWorker._postprocess( + dummy_worker, + inputs=[internal_b], + input_non_tensor_batch={ + "index": np.array([1], dtype=object), + "agent_name": np.array(["partial_single_turn_agent"], dtype=object), + }, + ) + merged: DataProto = DataProto.concat([chunk_a, chunk_b]) + + # Stable schema: present regardless of which loop produced a sample. + stable_keys = ( + "turn_scores", + "tool_rewards", + "is_cancel", + "param_version_start", + "param_version_end", + "extras", + ) + for key in stable_keys: + assert key in merged.non_tensor_batch, f"missing key in merged batch: {key}" + assert merged.non_tensor_batch[key].shape == (2,), ( + f"invalid shape for {key}: {merged.non_tensor_batch[key].shape}" + ) + + # And the list-typed fields are actually lists (not missing / scalar). + assert merged.non_tensor_batch["turn_scores"][0] == [] + assert merged.non_tensor_batch["tool_rewards"][0] == [] + assert merged.non_tensor_batch["turn_scores"][1] == [] + assert merged.non_tensor_batch["tool_rewards"][1] == [] diff --git a/tests/experimental/agent_loop/test_basic_agent_loop.py b/tests/experimental/agent_loop/test_basic_agent_loop.py index fcfb47ea9a6..6d675f6d724 100644 --- a/tests/experimental/agent_loop/test_basic_agent_loop.py +++ b/tests/experimental/agent_loop/test_basic_agent_loop.py @@ -28,6 +28,8 @@ from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema from verl.tools.schemas import ToolResponse from verl.utils import hf_tokenizer +from verl.utils.config import omega_conf_to_dataclass +from verl.workers.config import CheckpointEngineConfig @pytest.fixture @@ -345,8 +347,11 @@ def test_tool_agent_with_interaction(init_config): init_config.actor_rollout_ref.rollout.multi_turn.interaction_config_path = interaction_config_path init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 agent_loop_manager = init_agent_loop_manager(init_config) + checkpoint_engine_config = omega_conf_to_dataclass( + init_config.actor_rollout_ref.rollout.checkpoint_engine, CheckpointEngineConfig + ) checkpoint_manager = CheckpointEngineManager( - backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend, + config=checkpoint_engine_config, trainer=agent_loop_manager.worker_group, replicas=agent_loop_manager.rollout_replicas, ) diff --git a/tests/experimental/agent_loop/test_standalone_rollout.py b/tests/experimental/agent_loop/test_standalone_rollout.py index 96b7912045b..66ff31dae97 100644 --- a/tests/experimental/agent_loop/test_standalone_rollout.py +++ b/tests/experimental/agent_loop/test_standalone_rollout.py @@ -21,6 +21,7 @@ from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager from verl.checkpoint_engine import CheckpointEngineManager +from verl.utils import omega_conf_to_dataclass from verl.workers.rollout.replica import get_rollout_replica_class @@ -124,7 +125,7 @@ def test_hybrid_rollout_with_ep(init_config): # - sleep rollout and load FSDP model and optimizer agent_loop_manager = init_agent_loop_manager(init_config) checkpoint_manager = CheckpointEngineManager( - backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend, + config=omega_conf_to_dataclass(init_config.actor_rollout_ref.rollout.checkpoint_engine), trainer=agent_loop_manager.worker_group, replicas=agent_loop_manager.rollout_replicas, ) diff --git a/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py b/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py index 8c5174da2d5..0ea96dca409 100644 --- a/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py +++ b/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py @@ -25,6 +25,7 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.trainer.main_ppo import create_rl_sampler from verl.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.utils import omega_conf_to_dataclass from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn from verl.utils.device import get_device_name from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker @@ -97,10 +98,13 @@ def test_agent_reward_loop_standalone(): ) actor_rollout_wg.init_model() - agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg) + agent_loop_manager = AgentLoopManager.create( + config=config, + worker_group=actor_rollout_wg, + ) # sleep rollout replicas checkpoint_manager = CheckpointEngineManager( - backend=config.actor_rollout_ref.rollout.checkpoint_engine.backend, + config=omega_conf_to_dataclass(config.actor_rollout_ref.rollout.checkpoint_engine), trainer=actor_rollout_wg, replicas=agent_loop_manager.rollout_replicas, ) diff --git a/tests/experimental/reward_loop/test_agent_reward_loop_standalone.py b/tests/experimental/reward_loop/test_agent_reward_loop_standalone.py index bd9011b9874..80a0945bec7 100644 --- a/tests/experimental/reward_loop/test_agent_reward_loop_standalone.py +++ b/tests/experimental/reward_loop/test_agent_reward_loop_standalone.py @@ -56,6 +56,7 @@ def test_agent_reward_loop_standalone(): config.actor_rollout_ref.rollout.prompt_length = 1024 config.actor_rollout_ref.rollout.response_length = 4096 config.actor_rollout_ref.rollout.skip_tokenizer_init = True + config.actor_rollout_ref.rollout.nnodes = 1 config.trainer.n_gpus_per_node = 4 config.trainer.nnodes = 1 @@ -76,8 +77,9 @@ def test_agent_reward_loop_standalone(): # 1. init reward model manager reward_loop_manager = RewardLoopManager(config) - agent_loop_manager = AgentLoopManager( - config=config, reward_loop_worker_handles=reward_loop_manager.reward_loop_workers + agent_loop_manager = AgentLoopManager.create( + config=config, + reward_loop_worker_handles=reward_loop_manager.reward_loop_workers, ) # 2. init test data diff --git a/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh b/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh index 9f36a9dc860..bd081fd88cf 100644 --- a/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh +++ b/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh @@ -22,4 +22,4 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ actor_rollout_ref.rollout.name=hf \ trainer.use_legacy_worker_impl=disable \ - trainer.total_training_steps=2 \ No newline at end of file + trainer.total_training_steps=2 diff --git a/tests/special_e2e/run_fully_async_policy.sh b/tests/special_e2e/run_fully_async_policy.sh index cf372b618a2..01d807ba63a 100644 --- a/tests/special_e2e/run_fully_async_policy.sh +++ b/tests/special_e2e/run_fully_async_policy.sh @@ -133,6 +133,9 @@ common_params=( async_training.staleness_threshold=${staleness_threshold} async_training.partial_rollout="${partial_rollout}" async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" + # GPU specific configurations + actor_rollout_ref.rollout.checkpoint_engine.backend='nccl' + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=1024 ) if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then @@ -147,7 +150,7 @@ if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then python3 -m verl.experimental.fully_async_policy.fully_async_main \ "${common_params[@]}" \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.model.use_remove_padding=True \ diff --git a/tests/special_e2e/run_one_step_off_policy.sh b/tests/special_e2e/run_one_step_off_policy.sh index 895056624ee..bdcba5caaaf 100755 --- a/tests/special_e2e/run_one_step_off_policy.sh +++ b/tests/special_e2e/run_one_step_off_policy.sh @@ -88,8 +88,10 @@ common_params=( actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} actor_rollout_ref.rollout.val_kwargs.do_sample=True actor_rollout_ref.rollout.val_kwargs.n=1 - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.name=vllm + actor_rollout_ref.rollout.checkpoint_engine.backend='nccl' + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=1024 reward.reward_manager.name=dapo +reward.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} +reward.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} @@ -130,7 +132,7 @@ if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then python3 -m verl.experimental.one_step_off_policy.main_ppo \ "${common_params[@]}" \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.model.use_remove_padding=True \ diff --git a/tests/special_e2e/run_ppo_trainer_torchtitan.sh b/tests/special_e2e/run_ppo_trainer_torchtitan.sh new file mode 100644 index 00000000000..1ce2822d60d --- /dev/null +++ b/tests/special_e2e/run_ppo_trainer_torchtitan.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# Download model if not exists +MODEL_ID=${MODEL_ID:-Qwen/Qwen3-0.6B} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +#huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} +NUM_GPUS=${NUM_GPUS:-1} +FSDP_SIZE=${FSDP_SIZE:-1} +TP_SIZE=${TP_SIZE:-1} +EP_SIZE=${EP_SIZE:-1} +VERL_EXP_NAME=${VERL_EXP_NAME:-qwen3-0.6b-function-reward-minimal-fsdp-size1} + +python3 -m verl.trainer.main_ppo \ + model_engine=torchtitan \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=256 \ + data.seed=42 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.min_lr_factor=1.0 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.torchtitan.data_parallel_shard_size="${FSDP_SIZE}" \ + actor_rollout_ref.actor.torchtitan.tensor_parallel_size="${TP_SIZE}" \ + actor_rollout_ref.actor.torchtitan.expert_parallel_size="${EP_SIZE}" \ + actor_rollout_ref.actor.torchtitan.attn_type=flex \ + actor_rollout_ref.actor.torchtitan.use_torch_compile=False \ + actor_rollout_ref.actor.torchtitan.param_offload=False \ + actor_rollout_ref.actor.torchtitan.optimizer_offload=False \ + actor_rollout_ref.ref.torchtitan.use_torch_compile=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.n=5 \ + critic.optim.lr=1e-5 \ + critic.model.path="${MODEL_PATH}" \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger=['console','file','wandb'] \ + trainer.project_name='verl_grpo_example_gsm8k_0217' \ + trainer.experiment_name="${VERL_EXP_NAME}" \ + trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ + trainer.n_gpus_per_node="${NUM_GPUS}" \ + trainer.nnodes=1 \ + trainer.total_training_steps=100 $@ diff --git a/tests/special_e2e/run_ppo_trainer_veomni.sh b/tests/special_e2e/run_ppo_trainer_veomni.sh index 2e5c10d5e97..03eb2612417 100644 --- a/tests/special_e2e/run_ppo_trainer_veomni.sh +++ b/tests/special_e2e/run_ppo_trainer_veomni.sh @@ -38,7 +38,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.actor.use_torch_compile=False \ - actor_rollout_ref.actor.veomni.data_parallel_size="${FSDP_SIZE}" \ + actor_rollout_ref.actor.veomni.fsdp_size="${FSDP_SIZE}" \ actor_rollout_ref.actor.veomni.ulysses_parallel_size="${SP_SIZE}" \ actor_rollout_ref.actor.veomni.expert_parallel_size="${EP_SIZE}" \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ diff --git a/tests/special_e2e/sft/run_sft.sh b/tests/special_e2e/sft/run_sft.sh index 4cef7c68082..4078ff260ef 100644 --- a/tests/special_e2e/sft/run_sft.sh +++ b/tests/special_e2e/sft/run_sft.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -xeuo pipefail -ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.fsdp_sft_trainer"} +ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"} NUM_GPUS=${NUM_GPUS:-8} @@ -9,8 +9,8 @@ MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} #hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" -TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} +TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k_sft/train.parquet} +VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k_sft/test.parquet} SP_SIZE=${SP_SIZE:-1} LIGER=${LIGER:-False} @@ -34,28 +34,23 @@ mkdir -p "${ckpts_home}" torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \ data.train_files="${TRAIN_FILES}" \ data.val_files="${VAL_FILES}" \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - data.response_dict_keys=['answer'] \ - data.multiturn.enable="${MULTITURN}" \ - data.multiturn.messages_key=messages \ - optim.lr=1e-4 \ + data.messages_key=messages \ data.micro_batch_size_per_gpu=${micro_bsz} \ - model.strategy=fsdp \ - model.partial_pretrain="${MODEL_PATH}" \ + optim.lr=1e-4 \ + engine=fsdp \ + engine.ulysses_sequence_parallel_size="${SP_SIZE}" \ + model.path="${MODEL_PATH}" \ model.lora_rank="${LORA_RANK}" \ model.lora_alpha=16 \ model.target_modules=all-linear \ model.use_liger="${LIGER}" \ - ulysses_sequence_parallel_size="${SP_SIZE}" \ - use_remove_padding="${RM_PAD}" \ + model.use_remove_padding="${RM_PAD}" \ trainer.default_local_dir="${ckpts_home}" \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.total_training_steps=${TOTAL_TRAIN_STEP} \ trainer.save_freq=${SAVE_FREQ} \ - trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \ + checkpoint.save_contents=[model,optimizer,extra,hf_model] \ trainer.max_ckpt_to_keep=1 \ trainer.resume_mode=${RESUME_MODE} \ trainer.logger=['console'] $@ diff --git a/tests/special_e2e/sft/run_sft_engine.sh b/tests/special_e2e/sft/run_sft_engine.sh index 12ef3c2bfed..9fe80afae13 100644 --- a/tests/special_e2e/sft/run_sft_engine.sh +++ b/tests/special_e2e/sft/run_sft_engine.sh @@ -30,7 +30,7 @@ MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} #hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" SP_SIZE=${SP_SIZE:-1} -FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}} +FSDP_SIZE=${FSDP_SIZE:-1} FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"} TP_SIZE=${TP_SIZE:-1} @@ -44,6 +44,8 @@ USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} FSDP_ENGINE_CONFIG="\ engine=${backend} \ + model=hf_model \ + model.path=$MODEL_PATH \ optim=${backend} \ optim.lr=1e-5 \ optim.lr_warmup_steps_ratio=0.2 \ @@ -58,6 +60,8 @@ FSDP_ENGINE_CONFIG="\ VEOMNI_ENGINE_CONFIG="\ engine=${backend} \ + model=hf_model \ + model.path=$MODEL_PATH \ optim=${backend} \ optim.lr=1e-5 \ optim.lr_warmup_steps_ratio=0.2 \ @@ -67,11 +71,12 @@ VEOMNI_ENGINE_CONFIG="\ optim.lr_min=1e-6 \ optim.lr_scheduler_type=cosine \ engine.ulysses_parallel_size=${SP_SIZE} \ - engine.data_parallel_size=${FSDP_SIZE}" - + engine.fsdp_size=${FSDP_SIZE}" MEGATRON_ENGINE_CONFIG="\ engine=${backend} \ + model=hf_model \ + model.path=$MODEL_PATH \ optim=${backend} \ optim.lr=1e-5 \ optim.lr_warmup_steps_ratio=0.2 \ @@ -88,6 +93,26 @@ MEGATRON_ENGINE_CONFIG="\ +engine.override_transformer_config.context_parallel_size=${CP_SIZE} \ engine.use_mbridge=True" +TORCHTITAN_ENGINE_CONFIG="\ + engine=${backend} \ + model=hf_model \ + model.path=${MODEL_PATH} \ + optim=${backend} \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_factor=0.1 \ + optim.decay_type=cosine \ + optim.total_training_steps=1000 \ + engine.tensor_parallel_size=${TP_SIZE} \ + engine.pipeline_parallel_size=${PP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.data_parallel_shard_size=${FSDP_SIZE} \ + engine.use_torch_compile=False" + + if [ "$backend" = "fsdp" ]; then ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" echo "Using fsdp engine" @@ -96,6 +121,10 @@ elif [ "$backend" = "veomni" ]; then ENGINE_CONFIG="$VEOMNI_ENGINE_CONFIG" echo "Using veomni engine" exp_name=gsm8k-${backend}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} +elif [ "$backend" = "torchtitan" ]; then + ENGINE_CONFIG="$TORCHTITAN_ENGINE_CONFIG" + echo "Using torchtitan engine" + exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-dp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} else ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" echo "Using megatron engine" @@ -113,8 +142,8 @@ $COMMAND \ data.use_dynamic_bsz=True \ data.max_token_len_per_gpu=2048 \ data.messages_key=messages \ - model.path=$MODEL_PATH \ model.use_remove_padding=${USE_REMOVE_PADDING} \ + data.ignore_input_ids_mismatch=True \ ${ENGINE_CONFIG} \ trainer.test_freq=after_each_epoch \ trainer.save_freq=-1 \ @@ -129,5 +158,5 @@ $COMMAND \ # trainer.total_training_steps=${TOTAL_TRAIN_STEP} \ # trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \ # trainer.max_ckpt_to_keep=1 \ - -rm -rf "${ckpts_home:?}/*" \ No newline at end of file + +rm -rf "${ckpts_home:?}/*" diff --git a/tests/special_e2e/sft/test_sft_engine_all.sh b/tests/special_e2e/sft/test_sft_engine_all.sh index 96f5f195692..21524ce1d09 100644 --- a/tests/special_e2e/sft/test_sft_engine_all.sh +++ b/tests/special_e2e/sft/test_sft_engine_all.sh @@ -37,6 +37,15 @@ BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 b echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray" BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh +# TODO: Will add back torchtitan CI once everything is ready +# # test with torchtitan fsdp=2 +# echo "run with tp1 pp1 cp1 fsdp2 num_gpus2" +# BACKEND=torchtitan TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=2 bash tests/special_e2e/sft/run_sft_engine.sh + +# # test with torchtitan tp2 fsdp=2 +# echo "run with tp2 pp1 cp1 fsdp2 num_gpus4" +# BACKEND=torchtitan TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh + python3 tests/special_e2e/sft/compare_sft_engine_results.py rm -rf ~/verl/test/log diff --git a/tests/special_e2e/sft/test_sp_loss_match.py b/tests/special_e2e/sft/test_sp_loss_match.py deleted file mode 100644 index 5d8e59e721d..00000000000 --- a/tests/special_e2e/sft/test_sp_loss_match.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.distributed -from tensordict import TensorDict -from torch.distributed.device_mesh import init_device_mesh - -from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer -from verl.utils.distributed import initialize_global_process_group - - -def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4): - """Test consistency between original forward pass and SP+rmpad forward passes. - - Args: - trainer: The FSDPSFTTrainer instance to test - total_steps: Number of steps to test (default: 4) - """ - if trainer.device_mesh.get_rank() == 0: - print("\nStarting debug comparison between original and SP+rmpad forward passes...") - print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}") - print(f"Remove padding: {trainer.use_remove_padding}\n") - - steps_remaining = total_steps - - for epoch in range(1): # Just one epoch for testing - trainer.train_sampler.set_epoch(epoch=epoch) - for data in trainer.train_dataloader: - data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda() - trainer.fsdp_model.train() - micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu) - - for idx, micro_batch in enumerate(micro_batches): - if trainer.device_mesh.get_rank() == 0: - print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}") - - # Compute losses using both methods - # Disable SP and rmpad - trainer.use_remove_padding = False - old_sp = trainer.config.ulysses_sequence_parallel_size - trainer.config.ulysses_sequence_parallel_size = 1 - loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) - - # Do SP and rmpad - trainer.config.ulysses_sequence_parallel_size = old_sp - trainer.use_remove_padding = True - loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) - - # Collect losses across all ranks - loss_ref_all = loss_ref.clone() - loss_sp_all = loss_sp.clone() - torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG) - torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG) - - # Calculate relative difference of averaged losses - rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8) - - if trainer.device_mesh.get_rank() == 0: - print("\nComparison Results (Averaged across ranks):") - print(f"Reference Loss: {loss_ref_all.item():.6f}") - print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}") - print(f"Relative Difference: {rel_diff.item():.6f}") - - assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!" - print("Loss difference is within the acceptable range.") - - steps_remaining -= 1 - if steps_remaining == 0: - break - if steps_remaining == 0: - break - break - - if trainer.device_mesh.get_rank() == 0: - print("\nDebug comparison completed successfully.") - - -def create_trainer(config): - """Create and initialize a trainer instance with the given config. - - Args: - config: Configuration object with training parameters - - Returns: - FSDPSFTTrainer: Initialized trainer instance - """ - local_rank, rank, world_size = initialize_global_process_group() - - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) - - dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") - ) - - # build tokenizer and datasets first - from verl.trainer.fsdp_sft_trainer import create_sft_dataset - from verl.utils import hf_tokenizer - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) - tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) - train_dataset = create_sft_dataset( - config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1) - ) - val_dataset = create_sft_dataset( - config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1) - ) - - return FSDPSFTTrainer( - config=config, - device_mesh=device_mesh, - ulysses_device_mesh=ulysses_device_mesh, - tokenizer=tokenizer, - train_dataset=train_dataset, - val_dataset=val_dataset, - ) - - -def main(config): - """Main function to run trainer tests. - - Args: - config: Configuration object with training parameters - """ - trainer = create_trainer(config) - test_trainer_forward_consistency(trainer) - - -if __name__ == "__main__": - import hydra - from omegaconf import DictConfig - - @hydra.main(config_path="../../../verl/trainer/config", config_name="sft_trainer") - def hydra_entry(cfg: DictConfig) -> None: - main(cfg) - - hydra_entry() diff --git a/tests/special_e2e/run_transferqueue.sh b/tests/special_npu/run_fully_async_policy.sh similarity index 59% rename from tests/special_e2e/run_transferqueue.sh rename to tests/special_npu/run_fully_async_policy.sh index b76039d7bb0..fa517e81ae4 100644 --- a/tests/special_e2e/run_transferqueue.sh +++ b/tests/special_npu/run_fully_async_policy.sh @@ -1,13 +1,18 @@ #!/usr/bin/env bash set -xeuo pipefail +# Test script for fully_async_policy E2E regression testing +# This script runs fully async PPO training with both FSDP2 and Megatron backends +# to ensure the asynchronous training mechanism works correctly NUM_GPUS=${NUM_GPUS:-8} -ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp"} # fsdp or megatron +ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp2"} # fsdp2 or megatron # Download model if not exists MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +MODEL_PATH=${MODEL_PATH:-${HOME}/.cache/models/${MODEL_ID}} +# hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + rollout_mode="async" rollout_name="vllm" # sglang or vllm @@ -28,8 +33,8 @@ clip_ratio_low=0.2 clip_ratio_high=0.28 # Response length parameters -max_prompt_length=512 -max_response_length=1024 +max_prompt_length=1024 +max_response_length=2048 enable_overlong_buffer=True overlong_buffer_len=128 overlong_penalty_factor=1.0 @@ -43,61 +48,58 @@ top_p=1.0 top_k=-1 val_top_p=0.7 -n_gpus_training=8 -train_prompt_bsz=128 -val_prompt_bsz=128 -n_resp_per_prompt=5 -train_prompt_mini_bsz=32 -test_freq=-1 +# Fully async specific parameters +n_gpus_rollout=4 +n_gpus_training=4 -log_dir="./logs" -mkdir -p $log_dir -timestamp=$(date +"%Y%m%d%H%M%S") -log_file="${log_dir}/qwen2_5-0_5b_transferqueue_${timestamp}.log" +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=16 +total_rollout_steps=$(((128))) +test_freq=-1 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +partial_rollout=True -exp_name="$(basename "${MODEL_ID,}")-transferqueue-${ACTOR_STRATEGY}-minimal" +exp_name="$(basename "${MODEL_ID,,}")-fully-async-policy-${ACTOR_STRATEGY}-minimal" -echo "Running transferqueue with ${ACTOR_STRATEGY} strategy" -echo "Total GPUs: ${NUM_GPUS}" +echo "Running fully_async_policy with ${ACTOR_STRATEGY} strategy" +echo "Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}" -# Common parameters for both FSDP and Megatron +# Common parameters for both FSDP2 and Megatron common_params=( data.train_files="${HOME}/data/gsm8k/train.parquet" data.val_files="${HOME}/data/gsm8k/test.parquet" data.prompt_key=prompt - data.truncation='error' + data.truncation='left' data.max_prompt_length=${max_prompt_length} data.max_response_length=${max_response_length} - data.filter_overlong_prompts_workers=128 - data.filter_overlong_prompts=True data.train_batch_size=${train_prompt_bsz} - data.val_batch_size=${val_prompt_bsz} + data.gen_batch_size=${gen_prompt_bsz} data.return_raw_chat=${return_raw_chat} actor_rollout_ref.rollout.n=${n_resp_per_prompt} + actor_rollout_ref.rollout.calculate_log_probs=True algorithm.adv_estimator=${adv_estimator} algorithm.use_kl_in_reward=${use_kl_in_reward} algorithm.kl_ctrl.kl_coef=${kl_coef} + actor_rollout_ref.hybrid_engine=False actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} actor_rollout_ref.actor.clip_ratio_c=10.0 - actor_rollout_ref.actor.use_kl_loss=True actor_rollout_ref.model.path="${MODEL_PATH}" actor_rollout_ref.actor.optim.lr=1e-6 actor_rollout_ref.actor.optim.lr_warmup_steps=-1 actor_rollout_ref.actor.optim.weight_decay=0.1 actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 actor_rollout_ref.actor.entropy_coeff=0 actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 + actor_rollout_ref.rollout.gpu_memory_utilization=0.70 actor_rollout_ref.rollout.temperature=${temperature} actor_rollout_ref.rollout.top_p=${top_p} actor_rollout_ref.rollout.top_k=${top_k} - actor_rollout_ref.rollout.max_num_batched_tokens=10240 actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} @@ -106,43 +108,51 @@ common_params=( actor_rollout_ref.rollout.enable_chunked_prefill=True actor_rollout_ref.rollout.name=${rollout_name} actor_rollout_ref.rollout.mode=${rollout_mode} - actor_rollout_ref.rollout.disable_log_stats=True - trainer.logger=console - trainer.project_name='verl-test-transferqueue' + actor_rollout_ref.rollout.disable_log_stats=False + reward.reward_manager.name=dapo + +reward.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + +reward.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} + +reward.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} + +reward.reward_kwargs.overlong_buffer_cfg.log=False + +reward.reward_kwargs.max_resp_len=${max_response_length} + trainer.logger=['console'] + trainer.project_name='verl-test-fully-async' trainer.experiment_name="${exp_name}" - trainer.test_freq="${test_freq}" + trainer.val_before_train=True trainer.save_freq=-1 trainer.resume_mode=disable trainer.nnodes=1 trainer.n_gpus_per_node=${n_gpus_training} - trainer.total_training_steps=2 - trainer.total_epochs=15 - trainer.val_before_train=True + trainer.log_val_generations=10 + rollout.nnodes=1 + rollout.n_gpus_per_node=${n_gpus_rollout} + rollout.total_rollout_steps=${total_rollout_steps} + rollout.total_epochs=2 + rollout.test_freq=${test_freq} + # Fully async specific configurations + async_training.staleness_threshold=${staleness_threshold} + async_training.partial_rollout="${partial_rollout}" + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" + # NPU specific configurations + trainer.device='npu' + actor_rollout_ref.rollout.checkpoint_engine.backend='hccl' + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=1024 ) - # Detect device - device_name=$(python3 - <<'EOF' -from verl.utils.device import get_device_name -print(get_device_name()) -EOF -) - -if [ "${ACTOR_STRATEGY}" == "fsdp" ]; then - echo "Running TransferQueue training with FSDP strategy..." - # FSDP specific parameters; fsdp_size need to be -1 +if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then + echo "Running fully async training with FSDP2 strategy..." + # FSDP2 specific parameters gen_tp=1 sp_size=1 - fsdp_size=-1 + fsdp_size=1 ref_offload=True actor_offload=False - python3 -m verl.experimental.transfer_queue.main_ppo \ - --config-path=config \ - --config-name='transfer_queue_ppo_trainer' \ + python3 -m verl.experimental.fully_async_policy.fully_async_main \ "${common_params[@]}" \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.strategy=fsdp \ - critic.strategy=fsdp \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ + critic.strategy=fsdp2 \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.use_dynamic_bsz=True \ @@ -154,11 +164,10 @@ if [ "${ACTOR_STRATEGY}" == "fsdp" ]; then actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ - 2>&1 | tee "$log_file" $@ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@ elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then - echo "Running TransferQueue training with Megatron strategy..." + echo "Running fully async training with Megatron strategy..." # Megatron specific parameters gen_tp=2 train_tp=1 @@ -166,25 +175,16 @@ elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then ref_offload=True actor_offload=False - extra_flash_args=() - - if [ "$device_name" == "npu" ]; then - echo "Detect NPU device, enabling FlashAttention..." - extra_flash_args+=( - ++actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True - ) - fi - - # For Ascend NPU, please add: - #++actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \ - #++actor_rollout_ref.ref.megatron.override_transformer_config.use_flash_attn=True \ - python3 -m verl.experimental.transfer_queue.main_ppo \ + python3 -m verl.experimental.fully_async_policy.fully_async_main \ --config-path=config \ - --config-name='transfer_queue_ppo_megatron_trainer' \ + --config-name='fully_async_ppo_megatron_trainer.yaml' \ "${common_params[@]}" \ actor_rollout_ref.actor.strategy=megatron \ critic.strategy=megatron \ actor_rollout_ref.actor.optim.lr_decay_steps=10000000 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \ actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \ actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \ @@ -193,12 +193,11 @@ elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \ - "${extra_flash_args[@]}" \ - 2>&1 | tee "$log_file" $@ + actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@ else - echo "Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp' or 'megatron'" + echo "Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'" exit 1 fi -echo "TransferQueue test completed successfully with ${ACTOR_STRATEGY} strategy" \ No newline at end of file +echo "Fully async policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy" + diff --git a/tests/special_npu/run_one_step_off_policy.sh b/tests/special_npu/run_one_step_off_policy.sh index f9cbc89969c..2426a380fec 100644 --- a/tests/special_npu/run_one_step_off_policy.sh +++ b/tests/special_npu/run_one_step_off_policy.sh @@ -77,7 +77,7 @@ common_params=( actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} actor_rollout_ref.actor.entropy_coeff=0 actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 + actor_rollout_ref.rollout.gpu_memory_utilization=0.70 actor_rollout_ref.rollout.temperature=${temperature} actor_rollout_ref.rollout.top_p=${top_p} actor_rollout_ref.rollout.top_k=${top_k} @@ -86,9 +86,11 @@ common_params=( actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} actor_rollout_ref.rollout.val_kwargs.do_sample=True actor_rollout_ref.rollout.val_kwargs.n=1 - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.name=vllm \ - +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_AND_PIECEWISE" \ + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.name=vllm + actor_rollout_ref.rollout.checkpoint_engine.backend='hccl' + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=1024 + +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_AND_PIECEWISE" reward.reward_manager.name=dapo +reward.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} +reward.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} @@ -120,7 +122,7 @@ actor_offload=False python3 -m verl.experimental.one_step_off_policy.main_ppo \ "${common_params[@]}" \ - actor_rollout_ref.actor.strategy=$ACTOR_STRATEGY \ + actor_rollout_ref.actor.fsdp_config.strategy=$ACTOR_STRATEGY \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.model.use_remove_padding=True \ diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index b743913b77c..8310583c631 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -42,6 +42,7 @@ "verl/workers/engine/utils.py", # appear in enable_full_determinism "verl/workers/engine/fsdp/transformer_impl.py", # appear in default device_name "verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name + "verl/workers/engine/torchtitan/transformer_impl.py", # appear in default device_name "verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes "verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES "verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes diff --git a/tests/utils/dataset/test_sft_dataset_on_cpu.py b/tests/utils/dataset/test_sft_dataset_on_cpu.py deleted file mode 100644 index be91b598091..00000000000 --- a/tests/utils/dataset/test_sft_dataset_on_cpu.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -from verl.utils import hf_tokenizer -from verl.utils.dataset.sft_dataset import SFTDataset - - -def get_gsm8k_data(): - # prepare test dataset - local_folder = os.path.expanduser("~/data/gsm8k/") - local_path = os.path.join(local_folder, "train.parquet") - return local_path - - -def test_sft_cot_dataset(): - tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")) - local_path = get_gsm8k_data() - from omegaconf import OmegaConf - - dataset = SFTDataset( - parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create( - { - "prompt_key": "prompt", - "prompt_dict_keys": ["content"], - "response_key": "extra_info", - "response_dict_keys": ["answer"], - "max_length": 512, - } - ), - ) - - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode([data])[0] - assert len(output) > 1 - assert isinstance(output, str) - - -def test_sft_dataset(): - tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")) - local_path = get_gsm8k_data() - from omegaconf import OmegaConf - - dataset = SFTDataset( - parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create( - { - "prompt_key": "extra_info", - "prompt_dict_keys": ["question"], - "response_key": "extra_info", - "response_dict_keys": ["answer"], - "max_length": 512, - } - ), - ) - - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode([data])[0] - assert len(output) > 1 - assert isinstance(output, str) - - -def test_sft_dataset_with_max_samples(): - tokenizer = hf_tokenizer(os.path.expanduser("~/models/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")) - local_path = get_gsm8k_data() - from omegaconf import OmegaConf - - dataset = SFTDataset( - parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create( - { - "prompt_key": "extra_info", - "prompt_dict_keys": ["question"], - "response_key": "extra_info", - "response_dict_keys": ["answer"], - "max_length": 512, - } - ), - max_samples=5, - ) - - assert len(dataset) == 5 diff --git a/tests/utils/test_linear_cross_entropy.py b/tests/utils/test_linear_cross_entropy.py index 0512d1376de..801eaff27c5 100644 --- a/tests/utils/test_linear_cross_entropy.py +++ b/tests/utils/test_linear_cross_entropy.py @@ -34,6 +34,7 @@ import torch import verl.utils.torch_functional as verl_F +from verl.utils.device import is_torch_npu_available from verl.utils.experimental.torch_functional import FusedLinearForPPO from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy from verl.utils.torch_functional import logprobs_from_logits @@ -348,6 +349,52 @@ def check_storage_all(self): self.check_storage("Kernel", linear_cross_entropy) +def test_lce_non_divisible_vocab_padding(): + """Regression test for the logsumexp padding bug. + + When vocab_size % BLOCK_SIZE_N != 0 the last tile has fewer than + BLOCK_SIZE_N valid entries. Without the fix, out-of-bounds positions + are loaded as weight=0 → logit=0 → exp(0)=1, adding phantom probability + mass to the logsumexp denominator. For peaked softmax distributions + (small denominator) this causes large log-prob errors. + + Reproducing construction: one token-logit at +3, all others at -15 + → denominator ≈ 20, phantom adds ≈ 25 → error ≈ 0.82 per token. + """ + if not torch.cuda.is_available() or is_torch_npu_available(check_device=False): + return + + torch.manual_seed(0) + + V = 152064 # vocab_size % 1024 == 512 (triggers bug) + V_div = 149 * 1024 # vocab_size % 1024 == 0 (control) + D = 3584 + N = 512 + T = 1.5 + + def reference(hidden, weight, labels): + h = hidden.squeeze(0).float() + logits = torch.matmul(h, weight.float().T) / T + lp = -torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction="none") + pd = torch.nn.functional.softmax(logits, dim=-1) + ent = torch.logsumexp(logits, dim=-1) - (pd * logits).sum(-1) + return lp, ent + + for vocab_size, desc in [(V, "non-divisible vocab (mod1024=512)"), (V_div, "divisible vocab (mod1024=0)")]: + w = torch.zeros(vocab_size, D, dtype=torch.bfloat16, device="cuda") + w[:, 0] = -15.0 * T + w[0, 0] = 3.0 * T + h = torch.zeros(1, N, D, dtype=torch.bfloat16, device="cuda") + h[:, :, 0] = 1.0 + labels = torch.zeros(1, N, dtype=torch.long, device="cuda") + + ref_lp, ref_ent = reference(h, w, labels) + ker_lp, ker_ent = linear_cross_entropy(h, w, labels, T) + + torch.testing.assert_close(ref_lp, ker_lp, atol=1e-3, rtol=1e-3, msg=f"logprob mismatch: {desc}") + torch.testing.assert_close(ref_ent, ker_ent, atol=1e-3, rtol=1e-3, msg=f"entropy mismatch: {desc}") + + if __name__ == "__main__": # torch.cuda.memory._record_memory_history() @@ -358,4 +405,6 @@ def check_storage_all(self): test.verify_correctness() test.check_storage_all() + test_lce_non_divisible_vocab_padding() + # torch.cuda.memory._dump_snapshot("test_linear_cross_entropy.pkl") diff --git a/tests/workers/config/test_model_config_on_cpu.py b/tests/workers/config/test_model_config_on_cpu.py new file mode 100644 index 00000000000..e76985278ac --- /dev/null +++ b/tests/workers/config/test_model_config_on_cpu.py @@ -0,0 +1,96 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright Amazon.com and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +from omegaconf import OmegaConf + +from verl.workers.config.model import HFModelConfig + + +class TestHFModelConfigCPU: + model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B") # Just a path string, not loaded + + def test_target_modules_accepts_list_via_omegaconf(self): + """ + Test that target_modules field accepts both string and list values + when merging OmegaConf configs (simulates CLI override behavior). + + The purpose is to ensure we can pass + actor_rollout_ref.model.target_modules='["k_proj","o_proj","down_proj","q_proj"]' + """ + + # Create structured config from the dataclass defaults + # This is what omega_conf_to_dataclass does internally + cfg_from_dataclass = OmegaConf.structured(HFModelConfig) + + # Simulate CLI override with target_modules as a list + cli_config = OmegaConf.create( + { + "path": self.model_path, + "target_modules": ["k_proj", "o_proj", "q_proj", "v_proj"], + } + ) + + # This merge should NOT raise ValidationError + # Before the fix (target_modules: str), this would fail with: + # "Cannot convert 'ListConfig' to string" + merged = OmegaConf.merge(cfg_from_dataclass, cli_config) + + # Verify the list was merged correctly + assert list(merged.target_modules) == ["k_proj", "o_proj", "q_proj", "v_proj"] + + def test_target_modules_accepts_none_via_omegaconf(self): + """Test that target_modules still accepts None values.""" + + cfg_from_dataclass = OmegaConf.structured(HFModelConfig) + + cli_config = OmegaConf.create( + { + "path": self.model_path, + "target_modules": None, + } + ) + + merged = OmegaConf.merge(cfg_from_dataclass, cli_config) + assert merged.target_modules is None + + def test_target_modules_accepts_string_via_omegaconf(self): + """Test that target_modules still accepts string values.""" + + cfg_from_dataclass = OmegaConf.structured(HFModelConfig) + + cli_config = OmegaConf.create( + { + "path": self.model_path, + "target_modules": "all-linear", + } + ) + + merged = OmegaConf.merge(cfg_from_dataclass, cli_config) + assert merged.target_modules == "all-linear" + + def test_target_modules_raises_on_invalid_type(self): + """Test that __post_init__ raises TypeError for invalid target_modules types.""" + base_config = OmegaConf.structured(HFModelConfig) + invalid_cli_config = OmegaConf.create( + { + "path": self.model_path, + "target_modules": [1, 2, 3], # list of ints instead of strings + } + ) + merged_config = OmegaConf.merge(base_config, invalid_cli_config) + with pytest.raises(TypeError): + OmegaConf.to_object(merged_config) diff --git a/tests/workers/rollout/rollout_trtllm/test_adapter.py b/tests/workers/rollout/rollout_trtllm/test_adapter.py index 004df83d0eb..0acae344396 100644 --- a/tests/workers/rollout/rollout_trtllm/test_adapter.py +++ b/tests/workers/rollout/rollout_trtllm/test_adapter.py @@ -175,9 +175,9 @@ def test_init_without_device_mesh(self): worker0 = replica.workers[0] worker1 = replica.workers[1] - replica_rank = ray.get(worker0._get_attribute.remote("replica_rank")) - is_leader_rank_0 = ray.get(worker0._get_attribute.remote("is_leader_rank")) - is_leader_rank_1 = ray.get(worker1._get_attribute.remote("is_leader_rank")) + replica_rank = ray.get(worker0.get_replica_rank.remote()) + is_leader_rank_0 = ray.get(worker0.is_leader_rank.remote()) + is_leader_rank_1 = ray.get(worker1.is_leader_rank.remote()) assert replica_rank == 0 assert is_leader_rank_0 is True diff --git a/verl/checkpoint_engine/README.md b/verl/checkpoint_engine/README.md index 2318dd9477d..5cf1ece8c06 100644 --- a/verl/checkpoint_engine/README.md +++ b/verl/checkpoint_engine/README.md @@ -18,16 +18,27 @@ Checkpoint Engine is an unified abstract layer to synchronize weights between va |nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters |hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL| High|Low: rebuild hccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters |nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)
- UCX
- UCCL
- Mooncacke|Medium/High|High: dynamic adjust ring topology|Off-policy training
- Trainer/rollout disaggregated
- Elastic rollout
- Rollout fault tolerance
- Heterogeneous hardware rollout +|kimi_ckpt_engine|MOONCAKE+NCCL/HCCL|p2p+broadcast|NVIDIA/Ascend|High|Low: rebuild communication group|Off-policy training
- Trainer/rollout disaggregated
- Save checkpoint each time + +##### kimi_ckpt_engine detail: + +In the kimi_ckpt_engine workflow, the trainer first offloads the weights to the CPU, and the rollout creates a sub communication group that includes all the cards for the rollout. Then, using Mooncake transfer engine, these weights are transmitted via P2P to a specific worker in the rollout, followed by a broadcast to all other rollout workers. + +kimi-ckpt-engine + +This mode requires the P2P feature of checkpoint_engine. Please ensure you have installed it via pip install 'checkpoint-engine[p2p]' and that your version is 0.4.0 or higher. + +In addition, during the installation of checkpoint-engine[p2p], the transfer engine will be installed. However, This library has no prebuilt packages for Ascend devices and must be compiled from source. For detailed compilation instructions, see: [transfer-engine: ascend direct](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/design/transfer-engine/ascend_direct_transport.md) ### Benchmark 1. benchmark setup - model: Qwen/Qwen3-30B-A3B-Base -- trainer: fsdp world_size=2 +- trainer: fsdp world_size=2 (since Ascend 910C has 64GB of HBM, we set world_size=4) - rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang) ```bash -python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py -python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py -python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py +pytest tests/checkpoint_engine/test_correctness_on_gpu.py +pytest tests/checkpoint_engine/test_correctness_on_npu.py +pytest tests/checkpoint_engine/test_special_server_adapter.py ``` 2. benchmark result @@ -36,4 +47,5 @@ python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py |----|----|----|----| |4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25| |4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25| -|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3| \ No newline at end of file +|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3| +|2*16 Ascend 910C, inner suppernode| kimi_ckpt_engine | offload: 7 update: 3.5 | 16.5| diff --git a/verl/checkpoint_engine/__init__.py b/verl/checkpoint_engine/__init__.py index 4409369e8e8..e0c827aec7d 100644 --- a/verl/checkpoint_engine/__init__.py +++ b/verl/checkpoint_engine/__init__.py @@ -44,10 +44,16 @@ except ImportError: HCCLCheckpointEngine = None - try: from .nixl_checkpoint_engine import NIXLCheckpointEngine __all__ += ["NIXLCheckpointEngine"] except ImportError: NIXLCheckpointEngine = None + +try: + from .kimi_checkpoint_engine import KIMICheckpointEngine + + __all__ += ["KIMICheckpointEngine"] +except ImportError: + KIMICheckpointEngine = None diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py index f3a89c67d95..f722c1f4948 100644 --- a/verl/checkpoint_engine/base.py +++ b/verl/checkpoint_engine/base.py @@ -23,7 +23,7 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.utils.distributed import initialize_global_process_group_ray from verl.utils.ray_utils import auto_await -from verl.workers.config import HFModelConfig, RolloutConfig +from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class @@ -255,20 +255,32 @@ def __init__( rollout_config: RolloutConfig, model_config: HFModelConfig, server_adapter: BaseRollout = None, + *args, + **kwargs, ) -> None: + super().__init__() self.rollout_config = rollout_config self.model_config = model_config + self.server_adapter: BaseRollout = server_adapter + backend = self.rollout_config.checkpoint_engine.backend + bucket_size = self.rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20 + engine_kwargs = self.rollout_config.checkpoint_engine.engine_kwargs.get(backend, {}) + self.checkpoint_engine: CheckpointEngine = CheckpointEngineRegistry.new( + backend, bucket_size=bucket_size, **engine_kwargs + ) + self.extra_rollout_args = args + self.extra_rollout_kwargs = kwargs + if self.server_adapter is None: + self.server_adapter = get_rollout_class(self.rollout_config.name, self.rollout_config.mode)( + *self.extra_rollout_args, + config=self.rollout_config, + model_config=self.model_config, + device_mesh=None, + **self.extra_rollout_kwargs, + ) # sglang and trt-llm need device_mesh for internal communication initialize_global_process_group_ray(timeout_second=None, backend="cpu:gloo") - self.server_adapter: BaseRollout = server_adapter or get_rollout_class( - rollout_config.name, rollout_config.mode - )(config=rollout_config, model_config=model_config, device_mesh=None) - - backend = rollout_config.checkpoint_engine.backend - bucket_size = rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20 - engine_kwargs = rollout_config.checkpoint_engine.engine_kwargs.get(backend, {}) - self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) async def update_weights(self): @@ -279,6 +291,16 @@ async def update_weights(self): def execute_checkpoint_engine(self, method: str, *args, **kwargs): return getattr(self.checkpoint_engine, method)(*args, **kwargs) + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_replica_rank(self) -> int: + """Get replica rank from the underlying rollout server adapter.""" + return self.server_adapter.replica_rank + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def is_leader_rank(self) -> bool: + """Get leader rank flag from the underlying rollout server adapter.""" + return self.server_adapter.is_leader_rank + _worker_cls = ray.remote(CheckpointEngineWorker) @@ -307,19 +329,20 @@ class CheckpointEngineManager: ``` Args: - backend: The checkpoint engine backend. + config: The checkpoint engine config. trainer: The trainer worker group. replicas: The list of rollout replicas. """ def __init__( self, - backend: str, + config: CheckpointEngineConfig, trainer: RayWorkerGroup, replicas: list[RolloutReplica], ) -> None: - self.backend = backend - self.backend_cls = CheckpointEngineRegistry.get(backend) + self.config = config + self.backend = config.backend + self.backend_cls = CheckpointEngineRegistry.get(config.backend) self.trainer = trainer self.replicas = replicas diff --git a/verl/checkpoint_engine/hccl_checkpoint_engine.py b/verl/checkpoint_engine/hccl_checkpoint_engine.py index 18366d6b2a2..c4839999ddf 100644 --- a/verl/checkpoint_engine/hccl_checkpoint_engine.py +++ b/verl/checkpoint_engine/hccl_checkpoint_engine.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import logging import os import time @@ -67,8 +66,7 @@ def __init__( self.socket = socket self.topic = topic - loop = asyncio.get_running_loop() - self._task = loop.run_in_executor(None, self._run) + self._run() def _run(self): # broadcast tensor meta via zeromq PUB/SUB @@ -88,7 +86,6 @@ async def wait_for_complete(self) -> dict[str, TensorMeta]: Returns: dict[str, TensorMeta]: The bucket meta after broadcast. """ - await self._task return self.metadata @@ -148,6 +145,7 @@ def finalize(self): self.send_buf = None self.recv_buf = None + torch.npu.empty_cache() @classmethod def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): @@ -165,7 +163,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada def _start_zmq_server(self): self.ip = ray.util.get_node_ip_address().strip("[]") - self.zmq_port, self.listen_sock = get_free_port(self.ip) + self.zmq_port, _ = get_free_port(self.ip) context = zmq.Context() self.socket = context.socket(zmq.PUB) diff --git a/verl/checkpoint_engine/kimi_checkpoint_engine.py b/verl/checkpoint_engine/kimi_checkpoint_engine.py new file mode 100644 index 00000000000..f042c3489d8 --- /dev/null +++ b/verl/checkpoint_engine/kimi_checkpoint_engine.py @@ -0,0 +1,384 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import concurrent.futures +import logging +import os +import time +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import AsyncGenerator, Generator + +import checkpoint_engine.distributed as dist +import ray +import torch +from checkpoint_engine.ps import H2DBucket, ParameterMeta, ParameterServer, _gen_h2d_buckets, _to_named_tensor + +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry +from verl.utils.device import get_nccl_backend, get_torch_device +from verl.utils.net_utils import get_free_port + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def ckpt_get_named_tensor_buckets( + iterable: Generator[tuple[str, torch.Tensor], None, None], + bucket_bytes: int, + world_size: int, + rank_id: int, + rollout_dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + if bucket_bytes <= 0: + raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}") + + current_bucket = {} + current_size = 0 + for tensor_idx, (name, tensor) in enumerate(iterable): + tensor = tensor.to(rollout_dtype) + if tensor_idx % world_size == rank_id: + tensor_size = tensor.element_size() * tensor.numel() + if current_size + tensor_size > bucket_bytes: + if current_bucket: + yield current_bucket + current_bucket = {} + current_size = 0 + + current_bucket[name] = tensor + current_size += tensor_size + + if current_bucket: + yield current_bucket + + +async def receive_tensor( + self, + checkpoint_name: str, + ranks_group: int, + ranks: list[int] | None = None, + bucket_size: int = 2 << 30, + disable_h2d_buffer: bool = False, +) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty" + assert dist.is_initialized(), "process group is not initialized" + assert self._p2p_store is not None, "p2p store is not initialized" + assert ranks, "ranks should be set" + + # first execute a barrier to avoid subsequent device oom + dist.barrier(group=ranks_group) + buckets = _gen_h2d_buckets( + self._current_global_parameter_metas, + bucket_size, + self._local_rdma_devices, + self._remote_rdma_devices, + ranks, + ) + h2d_buffer: torch.Tensor | None = ( + None + if disable_h2d_buffer + else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type) + ) + # p2p store need to register h2d_buffer to let other ranks read + if ranks: + h2d_buffer_name = "__h2d_buffer__" + if h2d_buffer is not None and self._p2p_store is not None: + self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer}) + receiver_rank_buckets: list[tuple[int, H2DBucket]] = [] + for receiver_rank, owner_rank, bucket in buckets: + if receiver_rank != self._rank: + continue + receiver_rank_buckets.append((owner_rank, bucket)) + buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) + buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list) + + max_len = 0 + for receiver_rank, _, bucket in buckets: + buckets_by_receiver_rank[receiver_rank].append(bucket) + if len(buckets_by_receiver_rank[receiver_rank]) > max_len: + max_len = len(buckets_by_receiver_rank[receiver_rank]) + gidx = 0 + metadata: list[ParameterMeta] + try: + for i in range(max_len): + if i < len(receiver_rank_buckets) and not disable_h2d_buffer: + self._copy_to_buffer( + checkpoint_name, + receiver_rank_buckets[i][1], + h2d_buffer, + receiver_rank_buckets[i][0] if ranks else None, + ) + for receiver_rank, _buckets in buckets_by_receiver_rank.items(): + if i >= len(_buckets): + continue + bucket = _buckets[i] + start = gidx % 2 * bucket_size + buffer_b: torch.Tensor = buffer[start : start + bucket.size] + if receiver_rank == self._rank: + if disable_h2d_buffer: + self._copy_to_buffer(checkpoint_name, bucket, buffer_b) + else: + buffer_b.data.copy_(h2d_buffer[: bucket.size]) + broadcast_op = BroadcastOperation( + rank=receiver_rank, + ranks_group=ranks_group, + bucket=buffer_b, + metadata=bucket.items, + ) + if gidx == 0: + metadata = await broadcast_op.wait_for_complete() + gidx += 1 + continue + meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size) + for item in meta_list: + shape = item["shape"] + if isinstance(shape, list | tuple): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + yield item["name"], tensor + metadata = await broadcast_op.wait_for_complete() + self.device_manager.device_module.synchronize() + gidx += 1 + + meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size) + for item in meta_list: + shape = item["shape"] + if isinstance(shape, list | tuple): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + yield item["name"], tensor + + finally: + dist.barrier(group=ranks_group) + if ranks and h2d_buffer is not None: + self._p2p_store.unregister_named_tensors([h2d_buffer_name]) + self.device_manager.device_module.empty_cache() + + +@dataclass +class MasterMetadata: + zmq_ip: str + zmq_port: int + dist_ip: str + dist_port: int + + +class BroadcastOperation: + """Async broadcast operation with NCCL in separate thread. + + Args: + rank (int): The rank of the current process. + ranks_group (int): The process group's value. + bucket (torch.Tensor): The tensor to broadcast. + metadata (list[ParameterMeta]): The metadata of the tensor. + """ + + def __init__( + self, + rank: int, + ranks_group: int, + bucket: torch.Tensor, + metadata: list[ParameterMeta], + ) -> None: + self.rank = rank + self.ranks_group = ranks_group + self.bucket = bucket + self.metadata = metadata + + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + + def _run(self): + # broadcast tensor + dist.broadcast(self.bucket, src=self.rank, group=self.ranks_group) + + async def wait_for_complete(self) -> list[ParameterMeta]: + """Wait for the broadcast operation to complete. + + Returns: + list[ParameterMeta]: The bucket meta after broadcast. + """ + await self._task + return self.metadata + + +@CheckpointEngineRegistry.register("kimi_ckpt_engine") +class KIMICheckpointEngine(CheckpointEngine): + """NCCL checkpoint engine with collective communication. + + Args: + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use + two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size. + rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False. + is_master (bool): Whether the current process is the master process. Defaults to False. + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16. + """ + + def __init__( + self, + bucket_size: int, + rebuild_group: bool = False, + is_master: bool = False, + rollout_dtype: torch.dtype = torch.bfloat16, + ) -> None: + self.bucket_size = bucket_size + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + self.is_master = is_master + self.initialized = False + self.checkpoint_name = "kimi_checkpoint_engine" + + def prepare(self) -> MasterMetadata: + if self.is_master: + self.ip = ray.util.get_node_ip_address().strip("[]") + self.listen_port, _ = get_free_port(self.ip) + + return ( + MasterMetadata(zmq_ip=None, zmq_port=None, dist_ip=self.ip, dist_port=self.listen_port) + if self.is_master + else None + ) + + def finalize(self): + """Destroy the ckpt engine process group if rebuild_group is True.""" + if self.rebuild_group: + dist.destroy_process_group() + self.rank = None + self.world_size = None + self.initialized = False + + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "method": ["init_process_group"] * trainer_world_size, + "rank": list(range(0, trainer_world_size)), + "trainer_world_size": [trainer_world_size] * trainer_world_size, + "rollout_world_size": [rollout_world_size] * trainer_world_size, + "master_metadata": [metadata[0]] * trainer_world_size, + } + rollout_kwargs = { + "method": ["init_process_group"] * rollout_world_size, + "rank": list(range(trainer_world_size, trainer_world_size + rollout_world_size)), + "trainer_world_size": [trainer_world_size] * rollout_world_size, + "rollout_world_size": [rollout_world_size] * rollout_world_size, + "master_metadata": [metadata[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + def init_process_group( + self, + rank: int, + trainer_world_size: int, + rollout_world_size: int, + master_metadata: MasterMetadata, + ): + """Initialize the ckpt engine process group. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes. + """ + self.rank = rank + self.trainer_world_size = trainer_world_size + self.rollout_world_size = rollout_world_size + self.world_size = trainer_world_size + rollout_world_size + + if not self.initialized: + self.parameter_server = ParameterServer( + rank=rank, + world_size=self.world_size, + auto_pg=False, + master_addr=master_metadata.dist_ip, + master_port=master_metadata.dist_port, + ) + self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server) + + dist.use_backend(f"vllm_{get_nccl_backend()}") + self.parameter_server.init_process_group() + + self.rollout_ranks = list(range(self.trainer_world_size, self.world_size)) + self.rollout_group = dist.new_group(self.rollout_ranks) + self.initialized = True + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + + def offload_cpu(name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]: + return name, tensor.to("cpu", non_blocking=True) + + start_time = time.time() + named_tensors = {} + for named_tensors_gpu in ckpt_get_named_tensor_buckets( + weights, self.bucket_size, self.trainer_world_size, self.rank, self.rollout_dtype + ): + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + futures = [ + executor.submit( + offload_cpu, + name, + tensor, + ) + for name, tensor in named_tensors_gpu.items() + ] + for future in concurrent.futures.as_completed(futures): + name, tensor_cpu = future.result() + named_tensors[name] = tensor_cpu + + get_torch_device().synchronize() + + self.parameter_server.register_checkpoint(self.checkpoint_name, named_tensors=named_tensors) + named_tensors = {} + get_torch_device().empty_cache() + logger.info(f"Rank {self.rank} offload and register, time cost: {time.time() - start_time:.2f}s") + + self.parameter_server.gather_metas(self.checkpoint_name) + dist.barrier() + self.parameter_server.unregister_checkpoint(self.checkpoint_name) + logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + self.parameter_server.gather_metas(self.checkpoint_name) + + start_time = time.time() + total_bytes, total_params = 0, 0 + async for name, tensor in self.parameter_server.receive_tensor( + self.checkpoint_name, self.rollout_group, self.rollout_ranks, self.bucket_size + ): + total_bytes += tensor.element_size() * tensor.nelement() + total_params += 1 + yield name, tensor + dist.barrier() + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} receive weights done, total_params: {total_params}, " + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" + ) diff --git a/verl/checkpoint_engine/nccl_checkpoint_engine.py b/verl/checkpoint_engine/nccl_checkpoint_engine.py index e3f8d99447a..279733900d6 100644 --- a/verl/checkpoint_engine/nccl_checkpoint_engine.py +++ b/verl/checkpoint_engine/nccl_checkpoint_engine.py @@ -148,6 +148,8 @@ def finalize(self): self.send_buf = None self.recv_buf = None + torch.cuda.empty_cache() + @classmethod def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): trainer_kwargs = { @@ -164,7 +166,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada def _start_zmq_server(self): self.ip = ray.util.get_node_ip_address().strip("[]") - self.listen_port, self.listen_sock = get_free_port(self.ip) + self.listen_port, _ = get_free_port(self.ip) context = zmq.Context() self.socket = context.socket(zmq.PUB) diff --git a/verl/checkpoint_engine/nixl_checkpoint_engine.py b/verl/checkpoint_engine/nixl_checkpoint_engine.py index 5fd8c44f509..fbdefc5b230 100644 --- a/verl/checkpoint_engine/nixl_checkpoint_engine.py +++ b/verl/checkpoint_engine/nixl_checkpoint_engine.py @@ -82,7 +82,7 @@ def get_agent_metadata(self) -> NixlAgentMetadata: def start_zmq_server(self): self.ip = ray.util.get_node_ip_address().strip("[]") - self.listen_port, self.listen_sock = get_free_port(self.ip) + self.listen_port, _ = get_free_port(self.ip) context = zmq.asyncio.Context() self.socket = context.socket(zmq.PULL) diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 228d2248b7e..6dd3871c34e 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -35,24 +35,31 @@ from verl.experimental.agent_loop.utils import resolve_config_path from verl.protocol import DataProto from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup -from verl.utils import hf_processor, hf_tokenizer from verl.utils.chat_template import initialize_system_prompt +from verl.utils.config import omega_conf_to_dataclass from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class -from verl.utils.fs import copy_to_local from verl.utils.model import compute_position_id_with_mask -from verl.utils.ray_utils import get_event_loop +from verl.utils.ray_utils import auto_await, get_event_loop from verl.utils.rollout_trace import ( RolloutTraceConfig, rollout_trace_attr, rollout_trace_op, ) -from verl.utils.transferqueue_utils import tqbridge +from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +def _get_rollout_and_model_config(config: DictConfig) -> tuple[DictConfig, DictConfig]: + # TODO: backward compatibility, remove this once we switch to new trainer. + if config.get("actor_rollout_ref"): + return config.actor_rollout_ref.rollout, config.actor_rollout_ref.model + else: + return config.rollout, config.model + + class AsyncLLMServerManager: """ A class to manage multiple OpenAI compatible LLM servers. This class provides @@ -64,7 +71,7 @@ def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandl """Initialize the AsyncLLMServerManager. Args: - config (DictConfig): YAML config. + config (DictConfig): whole config for main entrypoint. server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000. """ @@ -190,7 +197,16 @@ def __init__(self, config: DictConfig): class AgentLoopBase(ABC): """An agent loop takes an input message, chat with OpenAI compatible LLM server and interact with various - environments.""" + environments. + + Args: + trainer_config (DictConfig): whole config for main entrypoint. + server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager. + tokenizer (AutoTokenizer): Tokenizer for tokenize messages. + processor (AutoProcessor): Processor for process messages. + dataset_cls (type[Dataset]): Dataset class for creating dataset, Defaults to RLHFDataset. + data_config (DictConfigWrap): Dataset config. + """ def __init__( self, @@ -199,26 +215,17 @@ def __init__( tokenizer: AutoTokenizer, processor: AutoProcessor, dataset_cls: type[RLHFDataset], - dataset_config: DictConfigWrap, + data_config: DictConfigWrap, **kwargs, ): - """Initialize agent loop, each sample will have its own loop instance. - - Args: - trainer_config (DictConfigWrap): trainer config. - server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager. - tokenizer (AutoTokenizer): Tokenizer for tokenize messages. - processor (AutoProcessor): Processor for process messages. - dataset_cls (type[Dataset]): Dataset class for creating dataset, Defaults to RLHFDataset. - dataset_config (DictConfigWrap): Dataset config. - """ self.config = trainer_config.config + self.rollout_config, _ = _get_rollout_and_model_config(self.config) self.server_manager = server_manager self.tokenizer = tokenizer self.processor = processor self.dataset_cls = dataset_cls - self.dataset_config = dataset_config.config - self.apply_chat_template_kwargs = self.dataset_config.get("apply_chat_template_kwargs", {}) + self.data_config = data_config.config + self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {}) self.system_prompt = initialize_system_prompt(self.tokenizer, **self.apply_chat_template_kwargs) self.loop = get_event_loop() @@ -234,7 +241,7 @@ async def process_vision_info(self, messages: list[dict]) -> dict: multi_modal_data = {} if self.processor is not None: images, videos = await self.dataset_cls.process_vision_info( - messages, image_patch_size=self.processor.image_processor.patch_size, config=self.dataset_config + messages, image_patch_size=self.processor.image_processor.patch_size, config=self.data_config ) if images is not None: multi_modal_data["images"] = images @@ -342,7 +349,13 @@ def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]: class AgentLoopWorker: - """Agent loop worker takes a batch of messages and run each message in an agent loop.""" + """Agent loop worker takes a batch of messages and run each message in an agent loop. + + Args: + config (DictConfig): whole config for main entrypoint. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. + """ def __init__( self, @@ -350,13 +363,10 @@ def __init__( server_handles: list[ray.actor.ActorHandle], reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): - """Initialize agent loop manager. - Args: - config (DictConfig): YAML config. - server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. - reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. - """ self.config = config + rollout_config, model_config = _get_rollout_and_model_config(config) + self.rollout_config: RolloutConfig = omega_conf_to_dataclass(rollout_config) + self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config) # for recipe to change if not hasattr(self, "server_manager"): @@ -365,33 +375,29 @@ def __init__( self.dataset_cls = get_dataset_class(config.data) self.reward_loop_worker_handles = reward_loop_worker_handles - model_path = config.actor_rollout_ref.model.path - self.model_name = "/".join(model_path.split("/")[-2:]) - local_path = copy_to_local(config.actor_rollout_ref.model.path) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) - self.processor = hf_processor(local_path, trust_remote_code=True) + self.tokenizer = self.model_config.tokenizer + self.processor = self.model_config.processor - agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path + agent_loop_config_path = self.rollout_config.agent.agent_loop_config_path if agent_loop_config_path: resolved_path = resolve_config_path(agent_loop_config_path) agent_loop_configs = OmegaConf.load(resolved_path) for agent_loop_config in agent_loop_configs: _agent_loop_registry[agent_loop_config.name] = agent_loop_config - if self.config.actor_rollout_ref.model.get("custom_chat_template", None) is not None: - if self.processor is not None: - self.processor.chat_template = self.config.actor_rollout_ref.model.custom_chat_template - self.tokenizer.chat_template = self.config.actor_rollout_ref.model.custom_chat_template + if self.model_config.get("custom_chat_template", None) is not None: + if self.model_config.processor is not None: + self.model_config.processor.chat_template = self.model_config.custom_chat_template + self.model_config.tokenizer.chat_template = self.model_config.custom_chat_template - trace_config = self.config.actor_rollout_ref.rollout.get("trace", {}) + trace_config = self.rollout_config.trace RolloutTraceConfig.init( - self.config.trainer.project_name, - self.config.trainer.experiment_name, + self.rollout_config.trace.project_name, + self.rollout_config.trace.experiment_name, trace_config.get("backend"), trace_config.get("token2text", False), trace_config.get("max_samples_per_step_per_worker", None), ) - @tqbridge() async def generate_sequences(self, batch: DataProto) -> DataProto: """Generate sequences from agent loop. @@ -413,7 +419,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| """ - config = self.config.actor_rollout_ref.rollout + config = self.rollout_config sampling_params = dict( temperature=config.temperature, top_p=config.top_p, @@ -502,7 +508,7 @@ async def _run_agent_loop( tokenizer=self.tokenizer, processor=self.processor, dataset_cls=self.dataset_cls, - dataset_config=DictConfigWrap(self.config.data), + data_config=DictConfigWrap(self.config.data), ) output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs) return await self._agent_loop_postprocess(output, **kwargs) @@ -536,7 +542,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO prompt_output = self.tokenizer.pad( {"input_ids": output.prompt_ids}, padding="max_length", - max_length=self.config.actor_rollout_ref.rollout.prompt_length, + max_length=self.rollout_config.prompt_length, return_tensors="pt", return_attention_mask=True, ) @@ -548,7 +554,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO response_output = self.tokenizer.pad( {"input_ids": output.response_ids}, padding="max_length", - max_length=self.config.actor_rollout_ref.rollout.response_length, + max_length=self.rollout_config.response_length, return_tensors="pt", return_attention_mask=True, ) @@ -559,7 +565,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO response_mask_output = self.tokenizer.pad( {"input_ids": output.response_mask}, padding="max_length", - max_length=self.config.actor_rollout_ref.rollout.response_length, + max_length=self.rollout_config.response_length, return_tensors="pt", return_attention_mask=False, ) @@ -568,7 +574,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO response_logprobs = None if output.response_logprobs is not None: - pad_size = self.config.actor_rollout_ref.rollout.response_length - len(output.response_logprobs) + pad_size = self.rollout_config.response_length - len(output.response_logprobs) response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0) response_mask = response_mask_output["input_ids"] * response_output["attention_mask"] @@ -777,8 +783,17 @@ def _postprocess( metrics = [input.metrics.model_dump() for input in inputs] # Collect extra fields from all inputs and convert them to np.ndarray + # Keep a stable set of keys so downstream batch concat stays consistent across agent loops. extra_fields = {} - all_keys = set(key for input_item in inputs for key in input_item.extra_fields) + default_extra_keys = { + "turn_scores", + "tool_rewards", + "is_cancel", + "param_version_start", + "param_version_end", + "extras", + } + all_keys = set(key for input_item in inputs for key in input_item.extra_fields) | default_extra_keys for key in all_keys: temp_arr = np.empty(len(inputs), dtype=object) temp_arr[:] = [input.extra_fields.get(key) for input in inputs] @@ -799,20 +814,6 @@ def _postprocess( meta_info=meta_info, ) - def create_transferqueue_client( - self, - ): - """Create a client for data system (TransferQueue).""" - from verl.single_controller.ray.base import get_random_string - from verl.utils.transferqueue_utils import create_transferqueue_client - - client_name = get_random_string(length=6) - - self.tq_client = create_transferqueue_client( - client_id=f"AgentLoopWorker_{client_name}", - config=self.config.transfer_queue, - ) - async def get_trajectory_info(step, index, validate): """Get trajectory info. @@ -837,7 +838,17 @@ async def get_trajectory_info(step, index, validate): class AgentLoopManager: - """Agent loop manager that manages a group of agent loop workers.""" + """Agent loop manager that manages a group of agent loop workers. + + - if worker_group is not None, rollout server is in hybrid mode, share GPUs with training engine. + - otherwise, rollout server is in standalone mode, use separate GPUs, e.g., one-step-off/fully async training. + + Args: + config (DictConfig): whole config for main entrypoint. + worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode. + rollout_resource_pool (RayResourcePool): Resource pool for hybrid mode, only used by TensorRT-LLM. + reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. + """ def __init__( self, @@ -846,63 +857,70 @@ def __init__( rollout_resource_pool: RayResourcePool = None, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): - """Initialize agent loop manager. - - Args: - config (DictConfig): trainer config. - worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode. - rollout_resource_pool (RayResourcePool): Resource pool for actor rollout (Colocate or Standalone mode). - reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. - """ self.config = config + self.rollout_config, self.model_config = _get_rollout_and_model_config(config) self.worker_group = worker_group + self.rollout_resource_pool = rollout_resource_pool self.reward_loop_worker_handles = reward_loop_worker_handles + assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode" + # for recipe to change if not hasattr(self, "rollout_replica_class"): - self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name) + self.rollout_replica_class = get_rollout_replica_class(self.rollout_config.name) if not hasattr(self, "agent_loop_workers_class"): self.agent_loop_workers_class = ray.remote(AgentLoopWorker) - self._initialize_llm_servers(rollout_resource_pool) - self._init_agent_loop_workers() + @classmethod + @auto_await + async def create( + cls, + config: DictConfig, + worker_group: RayWorkerGroup = None, + rollout_resource_pool: RayResourcePool = None, + reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, + ): + """Create agent loop manager.""" + instance = cls(config, worker_group, rollout_resource_pool, reward_loop_worker_handles) + await instance._initialize_llm_servers() + await instance._init_agent_loop_workers() + return instance - def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool): + async def _initialize_llm_servers(self): rollout_world_size = ( - self.config.actor_rollout_ref.rollout.tensor_model_parallel_size - * self.config.actor_rollout_ref.rollout.data_parallel_size - * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size + self.rollout_config.tensor_model_parallel_size + * self.rollout_config.data_parallel_size + * self.rollout_config.pipeline_model_parallel_size ) world_size = ( self.worker_group.world_size if self.worker_group - else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes + else self.rollout_config.n_gpus_per_node * self.rollout_config.nnodes ) num_replicas = world_size // rollout_world_size - rollout_config = self.config.actor_rollout_ref.rollout - model_config = self.config.actor_rollout_ref.model self.rollout_replicas = [ self.rollout_replica_class( replica_rank=replica_rank, - config=rollout_config, - model_config=model_config, - gpus_per_node=self.config.trainer.n_gpus_per_node, + config=self.rollout_config, + model_config=self.model_config, + gpus_per_node=self.rollout_config.n_gpus_per_node, ) for replica_rank in range(num_replicas) ] - if self.worker_group and rollout_config.name != "trtllm": - self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas]) - elif self.worker_group and rollout_config.name == "trtllm": - self._run_all( - [ - server.init_hybrid_colocated(self.worker_group, rollout_resource_pool) + if self.worker_group and self.rollout_config.name != "trtllm": + await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas]) + # TODO: unify trtllm to init_hybrid + elif self.worker_group and self.rollout_config.name == "trtllm": + await asyncio.gather( + *[ + server.init_hybrid_colocated(self.worker_group, self.rollout_resource_pool) for server in self.rollout_replicas ] ) else: - self._run_all([server.init_standalone() for server in self.rollout_replicas]) + await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas]) self.server_handles = [server._server_handle for server in self.rollout_replicas] self.server_addresses = [server._server_address for server in self.rollout_replicas] @@ -910,14 +928,14 @@ def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool): print(f"AgentLoopManager: {self.server_addresses}") # Update Prometheus configuration with server addresses - if rollout_config.prometheus.enable: - if rollout_config.disable_log_stats: + if self.rollout_config.prometheus.enable: + if self.rollout_config.disable_log_stats: raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.") - update_prometheus_config(rollout_config.prometheus, self.server_addresses, rollout_config.name) + update_prometheus_config(self.rollout_config.prometheus, self.server_addresses, self.rollout_config.name) - def _init_agent_loop_workers(self): + async def _init_agent_loop_workers(self): self.agent_loop_workers = [] - num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers + num_workers = self.rollout_config.agent.num_workers node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0] for i in range(num_workers): @@ -932,7 +950,8 @@ def _init_agent_loop_workers(self): ).remote(self.config, self.server_handles, self.reward_loop_worker_handles) ) - def generate_sequences(self, prompts: DataProto) -> DataProto: + @auto_await + async def generate_sequences(self, prompts: DataProto) -> DataProto: """Split input batch and dispatch to agent loop workers. Args: @@ -943,8 +962,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: """ chunkes = prompts.chunk(len(self.agent_loop_workers)) - outputs = ray.get( - [ + outputs = await asyncio.gather( + *[ worker.generate_sequences.remote(chunk) for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) ] @@ -985,20 +1004,17 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data return timing - def clear_kv_cache(self): + @auto_await + async def clear_kv_cache(self): """Clear all rollout kv cache, but don`t sleep.""" - self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas]) + await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas]) - def start_profile(self, **kwargs): + @auto_await + async def start_profile(self, **kwargs): """Start profiling on all rollout replicas.""" - self._run_all([replica.start_profile(**kwargs) for replica in self.rollout_replicas]) + await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas]) - def stop_profile(self): + @auto_await + async def stop_profile(self): """Stop profiling on all rollout replicas.""" - self._run_all([replica.stop_profile() for replica in self.rollout_replicas]) - - def _run_all(self, tasks: list[asyncio.Task]): - async def run_all(): - await asyncio.gather(*tasks) - - asyncio.run(run_all()) + await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas]) diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 7c479362aa4..6ad3aa429b3 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -17,7 +17,6 @@ from uuid import uuid4 from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register -from verl.tools.utils.tool_registry import initialize_tools_from_config from verl.utils.profiler import simple_timer logger = logging.getLogger(__file__) @@ -30,12 +29,8 @@ class SingleTurnAgentLoop(AgentLoopBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length - self.response_length = self.config.actor_rollout_ref.rollout.response_length - - tool_config_path = self.config.data.tool_config_path - tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] - self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + self.prompt_length = self.rollout_config.prompt_length + self.response_length = self.rollout_config.response_length async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: messages = list(kwargs["raw_prompt"]) @@ -48,7 +43,6 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu # 2. apply chat template and tokenize prompt_ids = await self.apply_chat_template( messages, - tools=self.tool_schemas, images=images, videos=videos, ) @@ -81,4 +75,8 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu num_turns=2, metrics=metrics, ) + + # keeping the schema consistent with tool_agent_loop + output.extra_fields.update({"turn_scores": [], "tool_rewards": []}) + return output diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index f98485a6781..c649a2fc3fd 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -21,13 +21,10 @@ import torch from PIL import Image -from transformers import AutoProcessor, AutoTokenizer from verl.experimental.agent_loop.agent_loop import ( AgentLoopBase, AgentLoopOutput, - AsyncLLMServerManager, - DictConfigWrap, register, ) from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser @@ -88,43 +85,35 @@ def __init__( # Temporary state for tool calls self.tool_calls: list[FunctionCall] = [] + self.routed_experts = None + # Extra fields for dynamic addition, e.g., tool session data self.extra_fields: dict[str, Any] = {} @register("tool_agent") class ToolAgentLoop(AgentLoopBase): - def __init__( - self, - trainer_config: DictConfigWrap, - server_manager: AsyncLLMServerManager, - tokenizer: AutoTokenizer, - processor: AutoProcessor, - **kwargs, - ): - super().__init__(trainer_config, server_manager, tokenizer, processor, **kwargs) - config = trainer_config.config + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) # Initialize tools from config file - self.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns - self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns - self.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls - self.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length - self.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side - tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path + self.max_user_turns = self.rollout_config.multi_turn.max_user_turns + self.max_assistant_turns = self.rollout_config.multi_turn.max_assistant_turns + self.max_parallel_calls = self.rollout_config.multi_turn.max_parallel_calls + self.max_tool_response_length = self.rollout_config.multi_turn.max_tool_response_length + self.tool_response_truncate_side = self.rollout_config.multi_turn.tool_response_truncate_side + tool_config_path = self.rollout_config.multi_turn.tool_config_path tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] self.tools = {tool.name: tool for tool in tool_list} self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] - self.tool_parser = ToolParser.get_tool_parser( - config.actor_rollout_ref.rollout.multi_turn.format, self.tokenizer - ) - self.tool_parser_name = config.actor_rollout_ref.rollout.multi_turn.format + self.tool_parser = ToolParser.get_tool_parser(self.rollout_config.multi_turn.format, self.tokenizer) + self.tool_parser_name = self.rollout_config.multi_turn.format - self.prompt_length = config.actor_rollout_ref.rollout.prompt_length - self.response_length = config.actor_rollout_ref.rollout.response_length + self.prompt_length = self.rollout_config.prompt_length + self.response_length = self.rollout_config.response_length # Initialize interactions from config file - self.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path + self.interaction_config_file = self.rollout_config.multi_turn.interaction_config_path if self.interaction_config_file: self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions( self.interaction_config_file @@ -203,6 +192,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu else None, num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, metrics=agent_data.metrics, + routed_experts=agent_data.routed_experts, extra_fields={}, ) output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards}) @@ -350,10 +340,14 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False) ) else: + # Note that we have to pass None to the images and videos if there are no new images / videos + # to stay compatible with downstream image processing logic! + images = new_images_this_turn if new_images_this_turn else None + videos = None response_ids = await self.apply_chat_template( add_messages, - images=new_images_this_turn, # Using local variable - videos=None, + images=images, + videos=videos, remove_system_prompt=True, ) diff --git a/verl/experimental/fully_async_policy/README.md b/verl/experimental/fully_async_policy/README.md index a24e7610102..b7ff1756459 100644 --- a/verl/experimental/fully_async_policy/README.md +++ b/verl/experimental/fully_async_policy/README.md @@ -105,9 +105,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev | `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization | | `async_training.staleness_threshold` | Freshness control | | `async_training.partial_rollout` | Whether to perform partial_rollout | -| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for accelerating, default `True` | -| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False` | -| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096` | | `async_training.use_trainer_do_validate` | Whether use trainer node to do validate process, default `False` | **Further Explanation:** @@ -181,27 +178,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev mode d (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`. -* `async_training.checkpoint_engine.enable` - - Enabling the checkpoint engine generally reduces synchronization time overhead by more than 60% compared to - the original per-tensor parameter synchronization method. However, assembling buckets incurs additional - temporary GPU memory overhead. - -* `async_training.checkpoint_engine.overlap_broadcast_and_consume` - - Enabling pipeline between the broadcast and load_weights parameters will allocate additional GPU memory. - Since the main time consumption for parameter synchronization is not in the broadcast and load_weights phases, - but in the parameter generation phase (by megatron or FSDP), this option is off by default. - -* `async_training.checkpoint_engine.device_buffer_size_M` - - It controls the size of the memory buffer used for synchronization when the checkpoint-engine is enabled. - The actual `bucket_size` = `max(device_buffer_size_M, maximum parameter tensor size)`. - * When enable `overlap_broadcast_and_consume`, the additional device memory overhead of - trainer rank is `3 * bucket_size`and rollout rank is `2 * bucket_size`。 - * When disable `overlap_broadcast_and_consume`, the additional device memory overhead of - trainer rank is `2 * bucket_size`and rollout rank is `1 * bucket_size`。 - * `async_training.use_trainer_do_validate` It controls whether to use the trainer's `do_validate` method for validation. diff --git a/verl/experimental/fully_async_policy/README_zh.md b/verl/experimental/fully_async_policy/README_zh.md index 19a257247c3..ad2e52e4167 100644 --- a/verl/experimental/fully_async_policy/README_zh.md +++ b/verl/experimental/fully_async_policy/README_zh.md @@ -82,9 +82,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev | `async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后,进行一次参数同步 | | `async_training.staleness_threshold` | 新鲜度控制 | | `async_training.partial_rollout` | 是否进行partial_rollout | -| `async_training.checkpoint_engine.enable` | 是否开启checkpoint_engine模式的加速,默认值True | -| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | 启动checkpoint_engine时,是否在参数同步时在broadcast和加载之间使用流水,默认值False | -| `async_training.checkpoint_engine.device_buffer_size_M` | 启动checkpoint_engine时,组装的bucket的大小(MB),默认为4096 | | `async_training.use_trainer_do_validate` | 是否使用Trainer的do_validate方法进行validation,默认值False | **进一步的解释:** @@ -146,20 +143,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev 在实际测试中,我们发现,如果单次下发的样本较少,由于数据分发的顺序,会导致训练不稳定,response 长度变长。 在这里,我们额外提供 require_batches 进行流式分发,单次参与训练的样本数量控制。 -* `async_training.checkpoint_engine.enable` - - 开启checkpoint engine后,相较于原始的逐tensor的参数同步方式,同步时间开销普遍可以降低60%以上。但是组装bucket会带来额外的临时显存开销。 - -* `async_training.checkpoint_engine.overlap_broadcast_and_consume` - - 开启参数broadcast和load_weights之间的流水后,会进一步额外申请更多显存。由于目前分析参数同步的主要耗时并非来自broadcast和load_weights阶段,而是在参数生成阶段(由megatron或FSDP),因此该开关默认关闭。 - -* `async_training.checkpoint_engine.device_buffer_size_M` - - 控制开启checkpoint engine后,用于同步的显存buffer大小。实际的`bucket_size` = `max(device_buffer_size_M, 最大参数tensor size)` - * 在开启`overlap_broadcast_and_consume`时,trainer节点的临时额外显存开销为 `3 * bucket_size`, rollout节点的临时额外显存开销为`2 * bucket_size`。 - * 在关闭`overlap_broadcast_and_consume`时,trainer节点的临时额外显存开销为 `2 * bucket_size`, rollout节点的临时额外显存开销为`1 * bucket_size`。 - * `async_training.use_trainer_do_validate` 控制是否使用trainer的`do_validate`方法进行validation。 diff --git a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py index 3098e48ba96..88a012224eb 100644 --- a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py @@ -28,11 +28,11 @@ AsyncLLMServerManager, DictConfigWrap, _agent_loop_registry, + _get_rollout_and_model_config, get_trajectory_info, ) -from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config from verl.protocol import DataProto -from verl.single_controller.ray import RayWorkerGroup +from verl.single_controller.ray import RayResourcePool, RayWorkerGroup from verl.utils.rollout_trace import ( rollout_trace_attr, rollout_trace_op, @@ -102,7 +102,7 @@ async def generate_sequences_no_post( Returns: list[AgentLoopOutput]: List of agent loop outputs, one per sample in the batch. """ - config = self.config.actor_rollout_ref.rollout + config = self.rollout_config sampling_params = dict( temperature=config.temperature, top_p=config.top_p, @@ -191,7 +191,7 @@ async def _partial_run_agent_loop( tokenizer=self.tokenizer, processor=self.processor, dataset_cls=self.dataset_cls, - dataset_config=DictConfigWrap(config=self.config.data), + data_config=DictConfigWrap(config=self.config.data), ) output: AgentLoopOutput = await agent_loop.run( sampling_params, cancellation_event=self.cancellation_event, **kwargs @@ -219,15 +219,17 @@ def __init__( self, config: DictConfig, worker_group: RayWorkerGroup = None, + rollout_resource_pool: RayResourcePool = None, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): self.config = config + self.rollout_config, self.model_config = _get_rollout_and_model_config(config) self.worker_group = worker_group self.reward_loop_worker_handles = reward_loop_worker_handles self.agent_loop_workers_class = FullyAsyncAgentLoopWorker # Select rollout replica class based on rollout name - rollout_name = config.actor_rollout_ref.rollout.name + rollout_name = self.rollout_config.name if rollout_name == "sglang": from verl.experimental.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica @@ -246,63 +248,6 @@ def __init__( self.server_addresses = None self.agent_loop_workers = None - @classmethod - async def create( - cls, - config: DictConfig, - worker_group: RayWorkerGroup = None, - reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, - ): - instance = cls(config, worker_group, reward_loop_worker_handles) - await instance._async_init() - return instance - - async def _async_init(self): - await self._initialize_llm_servers_async() - self._init_agent_loop_workers() - - async def _initialize_llm_servers_async(self): - rollout_world_size = ( - self.config.actor_rollout_ref.rollout.tensor_model_parallel_size - * self.config.actor_rollout_ref.rollout.data_parallel_size - * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size - ) - world_size = ( - self.worker_group.world_size - if self.worker_group - else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes - ) - num_replicas = world_size // rollout_world_size - - rollout_config = self.config.actor_rollout_ref.rollout - model_config = self.config.actor_rollout_ref.model - self.rollout_replicas = [ - self.rollout_replica_class( - replica_rank=replica_rank, - config=rollout_config, - model_config=model_config, - gpus_per_node=self.config.trainer.n_gpus_per_node, - ) - for replica_rank in range(num_replicas) - ] - - if self.worker_group: - await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas]) - else: - await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas]) - - self.server_handles = [server._server_handle for server in self.rollout_replicas] - self.server_addresses = [server._server_address for server in self.rollout_replicas] - - print(f"AgentLoopManager: {self.server_addresses}") - # Update Prometheus configuration with server addresses - if rollout_config.prometheus.enable: - if rollout_config.disable_log_stats: - raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.") - await asyncio.to_thread( - update_prometheus_config, rollout_config.prometheus, self.server_addresses, rollout_config.name - ) - async def generate_single_sample_async( self, sample: DataProto, diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py index cbfb4954f4f..6982184f8f6 100644 --- a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py @@ -30,9 +30,9 @@ class PartialSingleTurnAgentLoop(AgentLoopBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length - self.response_length = self.config.actor_rollout_ref.rollout.response_length - self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {}) + self.prompt_length = self.rollout_config.prompt_length + self.response_length = self.rollout_config.response_length + self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {}) async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: output: Optional[AgentLoopOutput] = kwargs.get("output", None) @@ -124,6 +124,8 @@ def get_prompt_ids(): "is_cancel": is_cancel, "param_version_start": param_version_start, "param_version_end": param_version_end, + "turn_scores": [], + "tool_rewards": [], }, multi_modal_data=multi_modal_data, # multi_modal_data={"image": image_data} if image_data is not None else {}, diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py index 0082fc13bc8..370587f0364 100644 --- a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py @@ -33,9 +33,9 @@ class AsyncPartialToolAgentLoop(ToolAgentLoop): """ - def __init__(self, trainer_config, **kwargs): - super().__init__(trainer_config, **kwargs) - self.enable_partial_rollout = trainer_config.config.async_training.get("partial_rollout", False) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.enable_partial_rollout = self.config.async_training.get("partial_rollout", False) # async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: async def run( diff --git a/verl/experimental/fully_async_policy/base_detach_sync.py b/verl/experimental/fully_async_policy/base_detach_sync.py deleted file mode 100644 index c0924417d78..00000000000 --- a/verl/experimental/fully_async_policy/base_detach_sync.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import logging -import os -import threading - -import torch -from omegaconf import DictConfig -from ray.util.collective import collective - -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils.device import get_torch_device, is_npu_available -from verl.utils.distributed import stateless_init_process_group - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class BaseDetachNcclSync: - _bucket_size_mb = 1024.0 - _sync_history = [] - _max_history_size = 20 - _last_avg_bucket_size = 1024.0 - - def __init__(self, config: DictConfig, role: str): - self._bg_loop = asyncio.new_event_loop() - self._bg_thread = threading.Thread( - target=self._start_background_loop, args=(self._bg_loop,), name="rollout_actor_async_worker", daemon=True - ) - self._bg_thread.start() - logger.info(f"[DetachNcclSync] Background thread for SGLang sync started. PID: {os.getpid()}") - - @classmethod - def get_bucket_size_mb(cls): - return cls._bucket_size_mb - - @classmethod - def get_last_avg_bucket_size(cls): - return cls._last_avg_bucket_size - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True) - def get_last_avg_bucket_size_remote(self): - return BaseDetachNcclSync._last_avg_bucket_size - - @classmethod - def record_sync_metrics(cls, bucket_size_mb, sync_time): - """Dynamically adjust the bucket size based on past synchronization times.""" - bucket_size_mb_value = bucket_size_mb[0] if isinstance(bucket_size_mb, list) else bucket_size_mb - print(f"[DetachNcclSync] sync_metrics: bucket_size_mb={bucket_size_mb_value:.2f}MB, sync_time={sync_time:.2f}s") - cls._sync_history.append((bucket_size_mb_value, sync_time)) - if len(cls._sync_history) > cls._max_history_size: - cls._sync_history.pop(0) - - MIN_BUCKET_SIZE_MB = 512 - MAX_BUCKET_SIZE_MB = 8192 # 8GB - - if len(cls._sync_history) < 4: - cls._bucket_size_mb = min(MAX_BUCKET_SIZE_MB, cls._bucket_size_mb * 1.5) - else: - times = [t for _, t in cls._sync_history] - buckets = [b for b, _ in cls._sync_history] - recent_avg_time = sum(times[-2:]) / 2 - previous_avg_time = sum(times[-4:-2]) / 2 - recent_avg_bucket = sum(buckets[-2:]) / 2 - previous_avg_bucket = sum(buckets[-4:-2]) / 2 - - performance_improved = recent_avg_time < previous_avg_time - bucket_increased = recent_avg_bucket > previous_avg_bucket - time_change_ratio = ( - abs(recent_avg_time - previous_avg_time) / previous_avg_time if previous_avg_time > 0 else 0.0 - ) - - if time_change_ratio > 0.2: - increase_step, decrease_step = 1.2, 0.8 - elif time_change_ratio > 0.1: - increase_step, decrease_step = 1.1, 0.9 - elif time_change_ratio > 0.05: - increase_step, decrease_step = 1.05, 0.95 - else: - increase_step, decrease_step = 1.02, 0.98 - - should_increase = (performance_improved and bucket_increased) or ( - not performance_improved and not bucket_increased - ) - step = increase_step if should_increase else decrease_step - new_size = cls._bucket_size_mb * step - cls._bucket_size_mb = min(MAX_BUCKET_SIZE_MB, max(MIN_BUCKET_SIZE_MB, new_size)) - - def _start_background_loop(self, loop): - asyncio.set_event_loop(loop) - try: - loop.run_forever() - except Exception as e: - logger.error(f"[DetachNcclSync] Background loop crashed: {e}") - - def _run_async_safely(self, coro): - if not self._bg_thread.is_alive(): - raise RuntimeError("Background thread for SGLang sync is not running!") - - future = asyncio.run_coroutine_threadsafe(coro, self._bg_loop) - return future.result() - - def __del__(self): - if hasattr(self, "_bg_loop") and self._bg_loop.is_running(): - self._bg_loop.call_soon_threadsafe(self._bg_loop.stop) - if hasattr(self, "_bg_thread") and self._bg_thread.is_alive(): - self._bg_thread.join(timeout=1.0) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int): - from .checkpoint_engine import CheckpointEngine - - current_rank = torch.distributed.get_rank() + rank_offset - actor_ranks = list(range(actor_num)) - rollout_ranks = [rank + actor_num for rank in range(rollout_num)] - assert rank_offset == 0 or rank_offset == actor_num - - self.checkpoint_engine = CheckpointEngine( - current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M - ) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size): - rank = torch.distributed.get_rank() + rank_offset - self._weight_sync_group = stateless_init_process_group( - master_address, - master_port, - rank, - world_size, - get_torch_device().current_device(), - ) - - @staticmethod - def get_inference_model(rollout): - """ - Get models according to different types of inference_engine - Args: - rollout: rollout object - Returns: - model: model object (for vllm) or rollout object itself (for sglang) - """ - inference_engine = rollout.inference_engine - if hasattr(inference_engine, "llm_engine"): - inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model - elif hasattr(inference_engine, "worker"): - inference_model = inference_engine.worker.model_runner.model - else: - raise AttributeError( - f"Unsupported inference_engine type: {type(inference_engine)}. " - f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)." - ) - return inference_model - - def _sync_sglang_weights(self, inference_model, params, sync_group_name): - bucket_size_bytes = int(self.get_bucket_size_mb() * 1024 * 1024) - actual_bucket_sizes = [] - current_batch = [] - current_batch_size = 0 - - def flush_batch(): - if current_batch: - actual_bucket_sizes.append(current_batch_size / (1024 * 1024)) - self._run_async_safely(self.update_weights(inference_model, iter(current_batch))) - get_torch_device().synchronize() - current_batch.clear() - - for key, shape, dtype in self._weights_info: - tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) - if self._is_actor: - assert key in params - origin_data = params[key] - if hasattr(origin_data, "full_tensor"): - origin_data = origin_data.full_tensor() - if torch.distributed.get_rank() == 0: - tensor.copy_(origin_data) - collective.broadcast(tensor, src_rank=0, group_name=sync_group_name) - - tensor_size = tensor.numel() * tensor.element_size() - current_batch.append((key, tensor)) - current_batch_size += tensor_size - - if current_batch_size >= bucket_size_bytes: - flush_batch() - current_batch_size = 0 - - flush_batch() - cls = type(self) - cls._last_avg_bucket_size = ( - sum(actual_bucket_sizes) / len(actual_bucket_sizes) if actual_bucket_sizes else self.get_bucket_size_mb() - ) - - # Resume kv_cache after weights sync to restore GPU memory released during pause - if self._is_rollout and self.rollout_device_mesh["infer_tp"].get_local_rank() == 0: - self._run_async_safely(inference_model.resume_memory_occupation(tags=["kv_cache"])) - - def _sync_vllm_weights(self, inference_model, params, sync_group_name): - for key, shape, dtype in self._weights_info: - tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) - if self._is_actor: - assert key in params - origin_data = params[key] - if hasattr(origin_data, "full_tensor"): - origin_data = origin_data.full_tensor() - if torch.distributed.get_rank() == 0: - tensor.copy_(origin_data) - if is_npu_available: - self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream()) - else: - collective.broadcast(tensor, src_rank=0, group_name=sync_group_name) - if self._is_rollout: - inference_model.load_weights([(key, tensor)]) - - async def update_weights(self, inference_engine, params): - from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights - - await sgl_update_weights( - engine=inference_engine, - params_batch=params, - device_mesh_key="infer_tp", - device_mesh=self.rollout_device_mesh, - ) - - if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0: - await inference_engine.flush_cache() diff --git a/verl/experimental/fully_async_policy/checkpoint_engine.py b/verl/experimental/fully_async_policy/checkpoint_engine.py deleted file mode 100644 index 28f932d61b3..00000000000 --- a/verl/experimental/fully_async_policy/checkpoint_engine.py +++ /dev/null @@ -1,522 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This logic is largely copied from: -- https://github.com/MoonshotAI/checkpoint-engine -""" - -import concurrent.futures -import os -import re -import socket -import subprocess -import threading -from collections.abc import Callable -from functools import lru_cache -from typing import TYPE_CHECKING, Annotated, Any, TypedDict - -import torch -import zmq -from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema -from ray.util.collective import collective - -from verl.utils.device import ( - get_device_name, - get_torch_device, -) - -if TYPE_CHECKING: - from typing import TypeVar - - from typing_extensions import TypedDict - - class FileMeta(TypedDict): - key: str # parameter name - dtype: torch.dtype - shape: torch.Size - type: type - tp_concat_dim: int - - T = TypeVar("T") - - -def _dt_validate(value: Any) -> torch.dtype: - """Validate the input value to ensure it is a valid torch.dtype""" - if isinstance(value, str): - if not value.startswith("torch."): - raise ValueError(f"dtype {value} should start with torch.") - try: - value = getattr(torch, value.split(".")[1]) - except AttributeError as e: - raise ValueError(f"unknown dtype: {value}") from e - if not isinstance(value, torch.dtype): - raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}") - return value - - -# Annotated type for torch.dtype with validation and serialization -_TorchDtype = Annotated[ - torch.dtype, - PlainValidator(_dt_validate), - PlainSerializer(lambda x: str(x), return_type=str), - WithJsonSchema({"type": "string"}, mode="serialization"), -] - - -def _size_validate(value: Any) -> torch.Size: - """Validate the input value to ensure it is a valid torch.Size""" - if isinstance(value, list | tuple): - return torch.Size(value) - if not isinstance(value, torch.Size): - raise TypeError(f"size {value} should be torch.Size, got {type(value)}") - return value - - -# Annotated type for torch.Size with validation and serialization -_TorchSize = Annotated[ - torch.Size, - PlainValidator(_size_validate), - PlainSerializer(lambda x: tuple(x), return_type=tuple), - WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"), -] - - -def _tensor_validate(value: Any) -> torch.Tensor: - """Validate the input value to ensure it is a valid torch.Tensor""" - if isinstance(value, torch.Tensor): - return value - raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}") - - -# Annotated type for torch.Tensor with validation -_TorchTensor = Annotated[ - torch.Tensor, - PlainValidator(_tensor_validate), -] - - -class ParameterMeta(BaseModel): - """Metadata for a parameter including name, dtype, and shape""" - - name: str - dtype: _TorchDtype - shape: _TorchSize - - -class MemoryBuffer(BaseModel): - """ - MemoryBuffer assembles a group of parameter tensors into a single buffer, - and records the meta information of each original parameter. - """ - - buffer: _TorchTensor - size: int # size of buffer in bytes - metas: list[ParameterMeta] - - -class MemoryBufferMeta(BaseModel): - """The meta info of MemoryBuffer, but not store the buffer data""" - - size: int - metas: list[ParameterMeta] - - -# 256 bytes alignment when flatten torch tensors to uint8 buffer -_ALIGN_SIZE = 256 - - -def _align_size(dtype: torch.dtype, shape: torch.Size) -> int: - """ - Calculate the aligned size of a torch tensor - - If the tensor's size (in bytes) cannot be evenly divided by _ALIGN_SIZE, - it will be rounded up to the nearest multiple of _ALIGN_SIZE. - - Args: - dtype (torch.dtype): The data type of the tensor (e.g., torch.float32, torch.int64). - shape (torch.Size): The shape of the tensor, representing its dimensions. - - Returns: - int: The aligned size of the tensor in bytes. - """ - return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE - - -@lru_cache(maxsize=1) -def get_ip() -> str: - try: - # try to get ip from network interface - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(("8.8.8.8", 80)) - return s.getsockname()[0] - except Exception as e: # noqa: BLE001 - # fallback to get ip from hostname - print(f"fail to get ip from network interface, fallback to get ip from hostname: {e}") - return socket.gethostbyname(socket.gethostname()) - - -def npu_generate_uuid() -> str: - """Generate uuid for each npu device""" - str_pid = str(os.getpid()) - npu_num = 8 - try: - for npu_id in range(npu_num): - cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)] - result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603 - str_result = str(result.stdout) - if str_pid in str_result: - # In A3 server, one NPU has two chips. - match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result) - chip_count = int(match_chip_count.group(1)) - search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :] - match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid) - chip_id = int(match_chip_id.group(1)) - return f"{get_ip()}-{npu_id * chip_count + chip_id}" - raise ValueError("The current process is not running on the npu device") - except subprocess.CalledProcessError as e: - raise ValueError("The current process is not running on the npu device") from e - - -def _get_physical_device_id(device_index: int | None = None) -> str: - """ - Get the physical device (GPU or NPU) uuid of the current device - """ - try: - if get_device_name() == "npu": - return f"NPU-{npu_generate_uuid()}" - else: - return f"GPU-{get_torch_device().get_device_properties(device_index).uuid!s}" - except AssertionError as e: - raise ValueError(f"fail to get physical gpu id {device_index}") from e - - -class FlattenedTensorMetadata(TypedDict): - name: str - shape: torch.Size - dtype: torch.dtype - # specify the start offset of this tensor in shared ipc_buffer tensor - offset: int - - -def _to_flattened_tensor_meta(metas: list[ParameterMeta], offset: int = 0) -> list[FlattenedTensorMetadata]: - """ - compute the offset of each parameter in the buffer - - Args: - metas (list[ParameterMeta]): The list of parameter metas info - offset (int): The start offset of the buffer. Defaults to 0. - - Returns: - list[FlattenedTensorMetadata]: The list of FlattenedTensorMetadata: - """ - ret = [] - for meta in metas: - size = _align_size(meta.dtype, meta.shape) - ret.append( - { - "name": meta.name, - "dtype": meta.dtype, - "shape": meta.shape, - "offset": offset, - } - ) - offset += size - return ret - - -def _extract_weights( - flatten_metas: list[FlattenedTensorMetadata], buffer: torch.Tensor -) -> list[tuple[str, torch.Tensor]]: - """ - According to the flatten_metas and buffer, extract the weights - """ - - assert buffer is not None - weights: list[tuple[str, torch.Tensor]] = [] - for item in flatten_metas: - shape = item["shape"] - if isinstance(shape, list | tuple): - shape = torch.Size(shape) - assert isinstance(shape, torch.Size) - dtype, offset = item["dtype"], item["offset"] - size = dtype.itemsize * shape.numel() - tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) - weights.append((item["name"], tensor)) - return weights - - -class CheckpointEngine: - """ - CheckpointEngine class for control parameters synchronization. - Each trainer/rollout rank has a CheckpointEngine instance. - """ - - def __init__( - self, current_rank: int, actor_ranks: list[int], rollout_ranks: list[int], device_buffer_size_M: int - ) -> None: - self.current_rank = current_rank - self.actor_ranks = actor_ranks - self.rollout_ranks = rollout_ranks - # global_buckets saves the global MemoryBufferMeta infos. - # Thus each CheckpointEngine instance can control their operations in SPMD - self.global_buckets: dict[int, list[MemoryBufferMeta]] = None - # min device_buffer_size for h2d and broadcast - self.device_buffer_size_M = device_buffer_size_M - - # ipc config for broadcast in pipeline mode - self._zmq_ctx = zmq.Context() - self._zmq_addr_counter: int = 0 - device_index = self.current_rank % get_torch_device().device_count() - self._device_uuid = _get_physical_device_id(device_index) - - def register_checkpoint( - self, weights_info: list[tuple[str, torch.Size, torch.dtype]], cpu_named_params: dict[str, torch.Tensor] - ): - """ - Register checkpoint information and prepare memory buffers for parameter synchronization. - - This function organizes the parameters into memory buckets for efficient synchronization - and prepares pinned memory buffers for faster data transfer between CPU and device. - - Args: - weights_info (list[tuple[str, torch.Size, torch.dtype]]): - A list of tuples containing parameter name, shape, and data type. - cpu_named_params (dict[str, torch.Tensor]): - A dictionary mapping parameter names to their corresponding CPU tensors. - - Steps: - 1. Calculate the bucket size based on the largest parameter tensor size and the device buffer size. - 2. Organize parameters into global buckets for each actor rank, ensuring that the total size of each bucket - does not exceed the bucket size. - 3. For actor ranks, allocate pinned memory buffers for each bucket and copy the parameter tensors - into these buffers. - - Notes: - Each CheckpointEngine instance maintains the global buckets metas, - but stores part of parmas data in host memory - """ - bucket_size = max( - self.device_buffer_size_M << 20, max(_align_size(dtype, shape) for _, shape, dtype in weights_info) - ) - print( - f"set checkpoint_engine device buffer size: {self.device_buffer_size_M}M, " - f"and finally set it to {bucket_size >> 20}M considering the largest parameter tensor size" - ) - self.bucket_size = bucket_size - - # global_buckets saves the global MemoryBufferMeta infos. - if self.global_buckets is None: - self.global_buckets = {rank: [MemoryBufferMeta(size=0, metas=[])] for rank in self.actor_ranks} - - actor_ranks_size = len(self.actor_ranks) - assert actor_ranks_size > 0, f"actor_ranks:{self.actor_ranks} should not be empty" - for param_idx, (param_name, param_shape, param_dtype) in enumerate(weights_info): - # Each parameter is assigned to an actor rank, and only this rank will store it - assgin_rank = self.actor_ranks[param_idx % actor_ranks_size] - param_size = _align_size(param_dtype, param_shape) - - if self.global_buckets[assgin_rank][-1].size + param_size > bucket_size: - assert self.global_buckets[assgin_rank][-1].size, ( - f"global_buckets[{assgin_rank}][-1].size:{self.global_buckets[assgin_rank][-1].size}" - " should not be 0" - ) - self.global_buckets[assgin_rank].append(MemoryBufferMeta(size=0, metas=[])) - self.global_buckets[assgin_rank][-1].metas.append( - ParameterMeta(name=param_name, dtype=param_dtype, shape=param_shape) - ) - self.global_buckets[assgin_rank][-1].size += param_size - - def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]: - """Allocate pinned memory for a bucket.""" - buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) - return idx, buffer - - def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): - """Copy a tensor into a pinned memory buffer.""" - buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) - - memory_buffers = [] # for rollout rank, return empty buffer - if self.current_rank in self.actor_ranks: # is_actor - local_buckets = self.global_buckets[self.current_rank] - memory_buffers = [ - MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) for bucket in local_buckets - ] - - # Use thread pool to accelerate organize parameters into buckets - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - futures = [ - executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets) - ] - new_futures = [] - for future in concurrent.futures.as_completed(futures): - idx, buffer = future.result() - assert buffer.numel() == local_buckets[idx].size, ( - f"buffer numel {buffer.numel()} should be equal to bucket size {local_buckets[idx].size}" - ) - memory_buffers[idx].buffer = buffer - print( - f"[rank{self.current_rank}] register pin_memory for " - f" bucket {idx + 1}/{len(local_buckets)} finished, " - f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer" - ) - offset = 0 - for meta in local_buckets[idx].metas: - name = meta.name - tensor = cpu_named_params[name] - size = _align_size(tensor.dtype, tensor.shape) - assert size == _align_size(meta.dtype, meta.shape), ( - f"tensor {name} size {size} should be equal to " - f"meta size {_align_size(meta.dtype, meta.shape)}" - ) - new_futures.append(executor.submit(register_tensor, buffer, offset, tensor)) - offset += size - for future in concurrent.futures.as_completed(new_futures): - future.result() - - self.memory_buffers = memory_buffers - - def get_max_buckets_num_per_rank(self): - """ - Get the maximum number of buckets for all rank. - """ - assert self.global_buckets is not None - return max(len(buckets) for buckets in self.global_buckets.values()) - - def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]: - """ - Bind zmq socket for broadcast. - """ - - def zmq_handle(device_uuid: str) -> str: - return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock" - - socket_path = zmq_handle(self._device_uuid) - socket = self._zmq_ctx.socket(zmq.REQ) - socket.bind(socket_path) - self._zmq_addr_counter += 1 - return socket, socket_path - - def update_checkpoint(self, inference_model, group_name: str, overlap_broadcast_and_consume: bool = False): - """ - Update the checkpoint by broadcasting and loading weights. - - This function handles the synchronization of parameters across ranks by: - 1. Copying data from memory buffers to device buffers (h2d_buffer). - 2. Broadcasting the data to all ranks using collective communication. - 3. Loading the weights into the inference model if provided. - 4. Optionally, use a pipeline approach for broadcasting and loading weights. - - Args: - inference_model: The model to load weights into. If None (trainer rank), weights are only broadcasted. - group_name (str): The name of the collective communication group. - overlap_broadcast_and_consume (bool): Whether to use the pipeline approach - for broadcasting and loading weights. - """ - try: - h2d_buffer: torch.Tensor | None = ( - None - if self.current_rank in self.rollout_ranks - else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device()) - ) - # for pipeline mode, we need to allocate 2x buffer size - broadcast_load_buffer = torch.empty( - self.bucket_size * (2 if overlap_broadcast_and_consume else 1), - dtype=torch.uint8, - device=get_torch_device().current_device(), - ) - except Exception: - print( - "allocate buffer for update_checkpoint failed, " - "you may need to reduce " - "config.async_training.checkpoint_engine.device_buffer_size_M" - ) - raise - - max_h2d_iter = self.get_max_buckets_num_per_rank() - - if overlap_broadcast_and_consume: - socket, socket_path = self._bind_zmq_socket() - - # Define a function to update weights from IPC - def update_weights_from_ipc_(socket_path): - zmq_ctx = zmq.Context() - socket = zmq_ctx.socket(zmq.REP) - socket.connect(socket_path) - socket.recv_pyobj() - socket.send(b"") - - while True: - payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj() - if payload is None: - # means the update is done - get_torch_device().synchronize() - socket.send(b"") - break - assert isinstance(payload, list) - if inference_model is not None: - inference_model.load_weights(_extract_weights(payload, broadcast_load_buffer)) - get_torch_device().synchronize() - socket.send(b"") - - req_thread = threading.Thread( - target=update_weights_from_ipc_, - args=(socket_path,), - ) - req_thread.start() - socket.send_pyobj(b"") - get_torch_device().synchronize() - - gidx = 0 - local_buckets = self.global_buckets.get(self.current_rank, []) - - for i in range(max_h2d_iter): - # Step 1: Each actor rank copy the parameter tensor into device memory - if i < len(self.memory_buffers): - h2d_buffer[: local_buckets[i].size].data.copy_(self.memory_buffers[i].buffer) - - # Step 2: Broadcast the device data in turn - for broadcast_rank, _buckets in self.global_buckets.items(): - if i >= len(_buckets): - continue - bucket = _buckets[i] - - # Prepare the broadcast buffer - start = gidx % 2 * self.bucket_size if overlap_broadcast_and_consume else 0 - buffer_b: torch.Tensor = broadcast_load_buffer[start : start + bucket.size] - if broadcast_rank == self.current_rank: - buffer_b.data.copy_(h2d_buffer[: bucket.size]) - - # Broadcast the buffer to all ranks - collective.broadcast(buffer_b, src_rank=broadcast_rank, group_name=group_name) - - if overlap_broadcast_and_consume: - socket.recv() - collective.barrier(group_name=group_name) - socket.send_pyobj(_to_flattened_tensor_meta(bucket.metas, start)) - elif inference_model is not None: - named_tensor = _to_flattened_tensor_meta(bucket.metas, 0) - inference_model.load_weights(_extract_weights(named_tensor, buffer_b)) - - gidx += 1 - - if overlap_broadcast_and_consume: - socket.recv() - socket.send_pyobj(None) - socket.recv() - req_thread.join() - socket.close() - - collective.barrier(group_name=group_name) - # clear host memory cache - self.memory_buffers = [] diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml index eece540865c..5ef1bc6c813 100644 --- a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml +++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml @@ -6,6 +6,9 @@ defaults: - ppo_megatron_trainer - _self_ +trainer: + use_legacy_worker_impl: disable + async_training: # Maximum samples staleness threshold @@ -25,17 +28,6 @@ async_training: use_trainer_do_validate: False - # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer - checkpoint_engine: - # Whether to use checkpoint_engine - enable: True - - # Device buffer size for checkpoint_engine, default is 4096 MB - device_buffer_size_M: 4096 - - # Enable the pipeline for broadcasting and updating parameters, but it requires more device memory - overlap_broadcast_and_consume: False - # Rollout config rollout: @@ -62,17 +54,15 @@ data: gen_batch_size: 1 actor_rollout_ref: - # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer - checkpoint_engine: ${oc.select:async_training.checkpoint_engine, null} rollout: # Must be turned off! Otherwise, Parameter synchronization cannot be performed. free_cache_engine: False # Must be enabled! Otherwise, log_probs cannot be calculated. calculate_log_probs: True - # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced. - # TODO: Can be removed in the future once parameter synchronization is ready. - load_format: "auto" + + checkpoint_engine: + backend: "nccl" actor: # Must use rollout log probs for training diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml index 7dece1cd479..1f4b4db8c82 100644 --- a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml +++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -6,6 +6,9 @@ defaults: - ppo_trainer - _self_ +trainer: + use_legacy_worker_impl: disable + async_training: # Maximum samples staleness threshold @@ -24,18 +27,6 @@ async_training: # whether to use trainer do_validate use_trainer_do_validate: False - - # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer - checkpoint_engine: - # Whether to use checkpoint_engine - enable: True - - # Device buffer size for checkpoint_engine, default is 4096 MB - device_buffer_size_M: 4096 - - # Enable the pipeline for broadcasting and updating parameters, but it requires more device memory - overlap_broadcast_and_consume: False - # Rollout config rollout: @@ -62,17 +53,15 @@ data: gen_batch_size: 1 actor_rollout_ref: - # checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer - checkpoint_engine: ${oc.select:async_training.checkpoint_engine, null} rollout: # Must be turned off! Otherwise, Parameter synchronization cannot be performed. free_cache_engine: False # Must be enabled! Otherwise, log_probs cannot be calculated. calculate_log_probs: True - # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced. - # TODO: Can be removed in the future once parameter synchronization is ready. - load_format: "auto" + + checkpoint_engine: + backend: "nccl" actor: # Must use rollout log probs for training @@ -82,4 +71,4 @@ actor_rollout_ref: # And it can be used in conjunction with other rollout_correction algorithms. algorithm: rollout_correction: - bypass_mode: True \ No newline at end of file + bypass_mode: True diff --git a/verl/experimental/fully_async_policy/fsdp_workers.py b/verl/experimental/fully_async_policy/fsdp_workers.py deleted file mode 100644 index 227ae06307d..00000000000 --- a/verl/experimental/fully_async_policy/fsdp_workers.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import time - -import torch -import torch.distributed -from omegaconf import DictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from verl.experimental.fully_async_policy.base_detach_sync import BaseDetachNcclSync -from verl.experimental.fully_async_policy.fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils.device import ( - get_device_name, - get_torch_device, -) -from verl.utils.fsdp_utils import ( - fsdp_version, - load_fsdp_model_to_gpu, - offload_fsdp_model_to_cpu, -) -from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -device_name = get_device_name() - -__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] - - -class DetachNcclSync(BaseDetachNcclSync, AsyncActorRolloutRefWorker): - def __init__(self, config: DictConfig, role: str): - BaseDetachNcclSync.__init__(self, config, role) - AsyncActorRolloutRefWorker.__init__(self, config, role) - - def _get_actor_params(self): - pass - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights(self, sync_group_name="actor_rollout"): - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - if self._is_actor and self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - params = self._get_actor_params() if self._is_actor else None - rollout_name = self.config.rollout.name - - inference_model = None - if self._is_rollout: - if rollout_name == "vllm": - inference_model = BaseDetachNcclSync.get_inference_model(self.rollout) - - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - inference_model = self.rollout._engine - # For ServerAdapter, _engine might be None and needs async initialization - if inference_model is None: - # Initialize the server adapter engine - print("[sync_rollout_weights] Initialize server adapter engine") - - async def init_engine(): - if hasattr(self.rollout, "_init_server_adapter"): - await self.rollout._init_server_adapter() - else: - print("[sync_rollout_weights] No _init_server_adapter method found") - return self.rollout._engine - - inference_model = self._run_async_safely(init_engine()) - # For ServerAdapter, only TP rank 0 initializes the engine - # TP rank != 0 can safely have inference_model as None - from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter - - is_server_adapter = isinstance(self.rollout, ServerAdapter) - is_non_tp_rank = False - if ( - is_server_adapter - and hasattr(self.rollout, "device_mesh") - and self.rollout.device_mesh is not None - ): - try: - is_non_tp_rank = self.rollout.device_mesh["infer_tp"].get_local_rank() != 0 - except Exception: - pass - - if inference_model is None and not (is_server_adapter and is_non_tp_rank): - raise RuntimeError( - f"Failed to initialize rollout engine. " - f"rollout type: {type(self.rollout)}, " - f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}" - ) - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - - if rollout_name == "sglang" and self._is_rollout: - self._sync_sglang_weights(inference_model, params, sync_group_name) - else: - self._sync_vllm_weights(inference_model, params, sync_group_name) - - if self._is_actor and self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - get_torch_device().empty_cache() - - def cache_actor_weights_to_cpu(self): - self.cpu_named_params = {} - if self._is_actor: - params = self._get_actor_params() - local_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - - for tensor_idx, (key, _, _) in enumerate(self._weights_info): - origin_data = params[key] - if hasattr(origin_data, "full_tensor"): - origin_data = origin_data.full_tensor() - - if tensor_idx % world_size == local_rank: - self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True) - get_torch_device().synchronize() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - # Load model to GPU - load_start_time = time.time() - if self._is_actor and self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - load_duration = time.time() - load_start_time - - from ray.util.collective import collective - - # Cache actor weights to CPU and measure the time taken - cache_start_time = time.time() - self.cache_actor_weights_to_cpu() - cache_end_time = time.time() - cache_duration = cache_end_time - cache_start_time - - # Register the cached weights into the checkpoint engine - self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params) - register_end_time = time.time() - register_duration = register_end_time - cache_end_time - self.cpu_named_params = {} - - collective.barrier(group_name=sync_group_name) - update_start_time = time.time() - - inference_model = None - if self._is_rollout: - inference_model = BaseDetachNcclSync.get_inference_model(self.rollout) - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(inference_model) - - # Update the checkpoint with the inference model and broadcast weights - self.checkpoint_engine.update_checkpoint( - inference_model=inference_model, - group_name=sync_group_name, - overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume, - ) - - update_end_time = time.time() - update_duration = update_end_time - update_start_time - - offload_start_time = time.time() - if self._is_actor and self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - offload_duration = time.time() - offload_start_time - - print( - f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()}," - f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout}," - f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, " - f" register cost {register_duration} seconds, update cost {update_duration} seconds" - ) - - if self._is_actor and self._is_offload_param: - print( - f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds," - f" offload model to cpu cost {offload_duration} seconds" - ) - - -class DetachActorWorker(DetachNcclSync): - def __init__(self, config: DictConfig, role: str): - print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...") - DetachNcclSync.__init__(self, config, role) - - def _get_actor_params(self): - assert self._is_actor - params = self.actor_module_fsdp.state_dict() - from verl.utils.model import convert_weight_keys - - params = convert_weight_keys( - params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) - ) - return params - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def get_actor_weights_info(self): - assert self._is_actor - if hasattr(self, "_weights_info"): - return self._weights_info - if fsdp_version(self.actor_module_fsdp) == 1: - from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType - - FSDP.set_state_dict_type( - self.actor_module_fsdp, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(), - ) - params = self._get_actor_params() - ret = [] - for key, tensor in params.items(): - ret.append((key, tensor.size(), tensor.dtype)) - self._weights_info = ret - return ret - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_model_to_cpu(self, n): - if not hasattr(self, "cpu_saved_models"): - self.cpu_saved_models = {} - self.cpu_saved_models[n] = fsdp2_sharded_save_to_cpu(self.actor_module_fsdp) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def restore_model_from_cpu(self, n): - if n in self.cpu_saved_models: - cpu_sharded_state, global_spec = self.cpu_saved_models[n] - fsdp2_sharded_load_from_cpu(self.actor_module_fsdp, cpu_sharded_state, global_spec) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def clear_cpu_model(self, n): - if n in self.cpu_saved_models: - del self.cpu_saved_models[n] - - -class DetachAsyncRolloutWorker(DetachNcclSync): - def __init__(self, config: DictConfig, role: str): - print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") - DetachNcclSync.__init__(self, config, role) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def set_actor_weights_info(self, weights_info): - assert self._is_rollout - self._weights_info = weights_info diff --git a/verl/experimental/fully_async_policy/fully_async_main.py b/verl/experimental/fully_async_policy/fully_async_main.py index ceeba563b18..4e9e509475f 100644 --- a/verl/experimental/fully_async_policy/fully_async_main.py +++ b/verl/experimental/fully_async_policy/fully_async_main.py @@ -82,45 +82,19 @@ def create_role_worker_mapping(config): # Select worker class based on strategy use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") if use_legacy_worker_impl == "disable": - from verl.experimental.separation.engine_workers import ( - DetachActorWorker, - DetachAsyncRolloutWorker, - TrainingWorker, - ) + from verl.experimental.separation.engine_workers import DetachActorWorker from verl.single_controller.ray import RayWorkerGroup + from verl.workers.engine_workers import TrainingWorker ray_worker_group_cls = RayWorkerGroup CriticWorker = TrainingWorker else: - if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.experimental.fully_async_policy.fsdp_workers import ( - CriticWorker, - DetachActorWorker, - DetachAsyncRolloutWorker, - ) - from verl.single_controller.ray import RayWorkerGroup - - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.critic.strategy == "megatron" - from verl.experimental.fully_async_policy.megatron_worker import ( - CriticWorker, - DetachActorWorker, - DetachAsyncRolloutWorker, - ) - from verl.single_controller.ray import RayWorkerGroup - - ray_worker_group_cls = RayWorkerGroup - else: - raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}") + raise NotImplementedError("Fully async policy does not support legacy worker implementation") train_role = Role.ActorRollout if config.async_training.use_trainer_do_validate else Role.Actor role_worker_mapping = { train_role: ray.remote(DetachActorWorker), - Role.Rollout: ray.remote(DetachAsyncRolloutWorker), Role.Critic: ray.remote(CriticWorker), } @@ -177,14 +151,14 @@ def _initialize_components(self, config) -> None: print("[ASYNC MAIN] Creating FullyAsyncRollouter and FullyAsyncTrainer in parallel...") with ThreadPoolExecutor(max_workers=2) as executor: - rollouter_future = executor.submit(self._create_rollouter, config) - rollouter_future.result() - # TODO: keep _create_rollouter and _create_trainer parallel + # Rollouter does not permit continuous allocation, so we allocate trainer first. trainer_future = executor.submit(self._create_trainer, config) - # Wait for both to complete trainer_future.result() + rollouter_future = executor.submit(self._create_rollouter, config) + rollouter_future.result() + # sync total_train_steps between rollouter and trainer total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote()) print(f"total_train_steps {total_train_steps}") @@ -233,7 +207,7 @@ def _create_rollouter(self, config) -> None: rollouter = FullyAsyncRollouter.remote( config=config, tokenizer=self.components["tokenizer"], - role_worker_mapping={Role.Rollout: self.components["role_worker_mapping"][Role.Rollout]}, + role_worker_mapping=None, resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]), ray_worker_group_cls=self.components["ray_worker_group_cls"], processor=self.components["processor"], @@ -313,6 +287,9 @@ def main(config): from time import time start_time = time() + # TODO: unify rollout config with actor_rollout_ref + config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes + config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node run_ppo(config, task_runner_class=FullyAsyncTaskRunner) print(f"total time: {time() - start_time:.2f} seconds") diff --git a/verl/experimental/fully_async_policy/fully_async_rollouter.py b/verl/experimental/fully_async_policy/fully_async_rollouter.py index 023b23a5c56..ce21be271cf 100644 --- a/verl/experimental/fully_async_policy/fully_async_rollouter.py +++ b/verl/experimental/fully_async_policy/fully_async_rollouter.py @@ -32,7 +32,7 @@ ) from verl.experimental.fully_async_policy.message_queue import MessageQueueClient from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray import RayWorkerGroup from verl.trainer.ppo.ray_trainer import ResourcePoolManager from verl.trainer.ppo.utils import Role, WorkerType from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path @@ -98,8 +98,20 @@ def __init__( from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler from verl.utils.dataset.rl_dataset import collate_fn - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + max_samples=config.data.get("val_max_samples", -1), + ) train_sampler = create_rl_sampler(config.data, train_dataset) self._validate_config() @@ -217,6 +229,10 @@ def get_rollout_wg(self): """Get rollout worker group""" return self.rollout_wg + def get_replicas(self): + """Get rollout worker group""" + return self.async_rollout_manager.rollout_replicas + def get_max_queue_size(self): return self.max_queue_size @@ -402,23 +418,13 @@ async def init_workers(self): 2. Worker groups for each role (actor, critic, etc.) """ self._init_async_objects() - self._init_resource_pools() self._create_worker_classes() - self._init_worker_groups() - self._init_models() self._init_reward_loop() await self._init_async_rollout_manager() def _create_actor_rollout_classes(self): - # only create rollout - for role in [Role.Rollout]: - resource_pool = self.resource_pool_manager.get_resource_pool(role) - role_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[role], - config=self.config.actor_rollout_ref, - role=str(role), - ) - self.resource_pool_to_cls[resource_pool][str(role)] = role_cls + # Skip rollout creation and let agentloop handle it + pass def _init_models(self): self.rollout_wg = self.all_wg[str(Role.Rollout)] @@ -755,12 +761,10 @@ async def pause(self): await asyncio.gather(*self.active_tasks, return_exceptions=True) self.active_tasks.clear() print("[FullyAsyncRollouter][Public][Pause] All active tasks completed") - - # TODO use checkpoint engine for rollout clear_kv_cache - # print("[FullyAsyncRollouter][Public][Pause] clear kv cache") - # # Always clear KV cache to release GPU memory during weight synchronization, - # # regardless of partial_rollout setting. - # await self.async_rollout_manager.clear_kv_cache() + print("[FullyAsyncRollouter][Public][Pause] Prefix cache reset") + # Always clear KV cache to release GPU memory during weight synchronization, + # regardless of partial_rollout setting. + await self.async_rollout_manager.clear_kv_cache() self.monitor_loop_trigger = False async def resume(self, dependency_ref: ObjectRef = None): diff --git a/verl/experimental/fully_async_policy/fully_async_trainer.py b/verl/experimental/fully_async_policy/fully_async_trainer.py index e0019be46a5..9519c594dbd 100644 --- a/verl/experimental/fully_async_policy/fully_async_trainer.py +++ b/verl/experimental/fully_async_policy/fully_async_trainer.py @@ -422,6 +422,7 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto: if batch is None: raise TrainingStopException("Training terminated: queue returned None") self._collect_metrics_from_samples(batch, metrics) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature return batch def _compute_old_log_prob(self, batch: DataProto): @@ -561,8 +562,6 @@ def _save_checkpoint(self): def load_checkpoint(self): if self.config.trainer.resume_mode == "disable": - # NOTE: while there is no checkpoint to load, we still need to offload the model and optimizer to CPU - self.actor_rollout_wg.load_checkpoint(None) return 0 # load from hdfs @@ -578,8 +577,6 @@ def load_checkpoint(self): # find global_step_folder if self.config.trainer.resume_mode == "auto": if global_step_folder is None: - print("[FullyAsyncTrainer] Training from scratch") - self.actor_rollout_wg.load_checkpoint(None) return 0 else: if self.config.trainer.resume_mode == "resume_path": diff --git a/verl/experimental/fully_async_policy/megatron_worker.py b/verl/experimental/fully_async_policy/megatron_worker.py deleted file mode 100644 index e0a1c2a437c..00000000000 --- a/verl/experimental/fully_async_policy/megatron_worker.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# Copyright 2025 NVIDIA Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import time - -import torch -import torch.distributed -from omegaconf import DictConfig - -from verl.experimental.fully_async_policy.base_detach_sync import BaseDetachNcclSync -from verl.experimental.fully_async_policy.megatron_utils import ( - copy_megatron_model_to_cpu, - restore_megatron_model_from_cpu, -) -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils.device import ( - get_device_name, - get_torch_device, -) -from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator -from verl.workers.megatron_workers import AsyncActorRolloutRefWorker, CriticWorker - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -device_name = get_device_name() - -__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] - - -class DetachNcclSync(BaseDetachNcclSync, AsyncActorRolloutRefWorker): - def __init__(self, config: DictConfig, role: str): - BaseDetachNcclSync.__init__(self, config, role) - - AsyncActorRolloutRefWorker.__init__(self, config, role) - - def _get_actor_params(self): - pass - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights(self, sync_group_name="actor_rollout"): - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - if self._is_actor and self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module, False) - params_generator = self._get_actor_params_generator() if self._is_actor else None - params = {key: tensor for key, tensor in params_generator} if params_generator is not None else None - - rollout_name = self.config.rollout.name - inference_model = None - if self._is_rollout and (not self._is_actor): - if rollout_name == "vllm": - inference_model = BaseDetachNcclSync.get_inference_model(self.rollout) - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - inference_model = self.rollout._engine - if inference_model is None: - print("[sync_rollout_weights] Initialize server adapter engine") - - async def init_engine(): - if hasattr(self.rollout, "_init_server_adapter"): - await self.rollout._init_server_adapter() - else: - print("[sync_rollout_weights] No _init_server_adapter method found") - return self.rollout._engine - - inference_model = self._run_async_safely(init_engine()) - # For ServerAdapter, only TP rank 0 initializes the engine - # TP rank != 0 can safely have inference_model as None - from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter - - is_server_adapter = isinstance(self.rollout, ServerAdapter) - is_non_tp_rank = False - if ( - is_server_adapter - and hasattr(self.rollout, "device_mesh") - and self.rollout.device_mesh is not None - ): - try: - is_non_tp_rank = self.rollout.device_mesh["infer_tp"].get_local_rank() != 0 - except Exception: - pass - - if inference_model is None and not (is_server_adapter and is_non_tp_rank): - raise RuntimeError( - f"Failed to initialize rollout engine. " - f"rollout type: {type(self.rollout)}, " - f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}" - ) - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - - if rollout_name == "sglang" and self._is_rollout: - self._sync_sglang_weights(inference_model, params, sync_group_name) - else: - self._sync_vllm_weights(inference_model, params, sync_group_name) - - if self._is_actor and self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - get_torch_device().empty_cache() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_model_to_cpu(self, n): - if not hasattr(self, "cpu_saved_models"): - self.cpu_saved_models = {} - self.cpu_saved_models[n] = copy_megatron_model_to_cpu(self.actor.actor_module) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def restore_model_from_cpu(self, n): - if n in self.cpu_saved_models: - restore_megatron_model_from_cpu(self.actor.actor_module, self.cpu_saved_models[n]) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def clear_cpu_model(self, n): - if n in self.cpu_saved_models: - del self.cpu_saved_models[n] - - def cache_actor_weights_to_cpu(self): - self.cpu_named_params = {} - if self._is_actor: - params_generator = self._get_actor_params_generator() - local_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - print(f"cache_actor_weights_to_cpu, local_rank:{local_rank}, world_size:{world_size}") - for tensor_idx, (key, tensor) in enumerate(params_generator): - if tensor_idx % world_size == local_rank: - self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True) - get_torch_device().synchronize() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - # Load model to GPU - load_start_time = time.time() - if self._is_actor and self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module, False) - load_duration = time.time() - load_start_time - - from ray.util.collective import collective - - # Cache actor weights to CPU and measure the time taken - cache_start_time = time.time() - self.cache_actor_weights_to_cpu() - cache_end_time = time.time() - cache_duration = cache_end_time - cache_start_time - - # Register the cached weights into the checkpoint engine - self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params) - register_end_time = time.time() - register_duration = register_end_time - cache_end_time - self.cpu_named_params = {} - - collective.barrier(group_name=sync_group_name) - update_start_time = time.time() - - rollout_name = self.config.rollout.name - inference_model = None - if self._is_rollout and (not self._is_actor): - if rollout_name == "vllm": - inference_model = BaseDetachNcclSync.get_inference_model(self.rollout) - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - inference_model = self.rollout._engine - # For ServerAdapter, _engine might be None and needs async initialization - if inference_model is None: - # Initialize the server adapter engine - print("[sync_rollout_weights] Initialize server adapter engine") - - async def init_engine(): - if hasattr(self.rollout, "_init_server_adapter"): - await self.rollout._init_server_adapter() - else: - print("[sync_rollout_weights] No _init_server_adapter method found") - return self.rollout._engine - - inference_model = self._run_async_safely(init_engine()) - # For ServerAdapter, only TP rank 0 initializes the engine - # TP rank != 0 can safely have inference_model as None - from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter - - is_server_adapter = isinstance(self.rollout, ServerAdapter) - is_non_tp_rank = False - if ( - is_server_adapter - and hasattr(self.rollout, "device_mesh") - and self.rollout.device_mesh is not None - ): - try: - is_non_tp_rank = self.rollout.device_mesh["infer_tp"].get_local_rank() != 0 - except Exception: - pass - - if inference_model is None and not (is_server_adapter and is_non_tp_rank): - raise RuntimeError( - f"Failed to initialize rollout engine. " - f"rollout type: {type(self.rollout)}, " - f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}" - ) - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - # Update the checkpoint with the inference model and broadcast weights - self.checkpoint_engine.update_checkpoint( - inference_model=inference_model, - group_name=sync_group_name, - overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume, - ) - - update_end_time = time.time() - update_duration = update_end_time - update_start_time - - collective.barrier(group_name=sync_group_name) - offload_start_time = time.time() - if self._is_actor and self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - offload_duration = time.time() - offload_start_time - - print( - f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()}," - f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout}," - f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, " - f" register cost {register_duration} seconds, update cost {update_duration} seconds" - ) - - if self._is_actor and self._is_offload_param: - print( - f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds," - f" offload model to cpu cost {offload_duration} seconds" - ) - - -class DetachActorWorker(DetachNcclSync): - def __init__(self, config: DictConfig, role: str): - print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...") - DetachNcclSync.__init__(self, config, role) - - def _get_actor_params_generator(self): - assert self._is_actor - if self.bridge is not None: - if self.vanilla_bridge: - generator = self.bridge.export_weights(self.actor.actor_module) - else: - generator = self.bridge.export_hf_weights(self.actor.actor_module) - else: - generator = per_tensor_generator( - self.actor.actor_module, - self.actor_model_config, - self.weight_converter, - self.tf_config, - self.layer_name_mapping, - ) - - return generator - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def get_actor_weights_info(self): - assert self._is_actor - if hasattr(self, "_weights_info"): - return self._weights_info - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module, False) - params_generator = self._get_actor_params_generator() - ret = [] - for key, tensor in params_generator: - ret.append((key, tensor.size(), tensor.dtype)) - - self._weights_info = ret - # Here, we only call this function at the beginning, - # and immediately afterwards we call sync_rollout_weights. - # So we no longer call offload in this. - return ret - - -class DetachAsyncRolloutWorker(DetachNcclSync): - def __init__(self, config: DictConfig, role: str): - print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") - DetachNcclSync.__init__(self, config, role) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def set_actor_weights_info(self, weights_info): - assert self._is_rollout - self._weights_info = weights_info diff --git a/verl/experimental/fully_async_policy/param_sync.py b/verl/experimental/fully_async_policy/param_sync.py index 000568d12d6..88982efef15 100644 --- a/verl/experimental/fully_async_policy/param_sync.py +++ b/verl/experimental/fully_async_policy/param_sync.py @@ -18,6 +18,8 @@ import ray from ray.util.collective import collective +from verl.checkpoint_engine import CheckpointEngineManager +from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import get_nccl_backend logger = logging.getLogger(__name__) @@ -50,11 +52,11 @@ def __init__(self, config, trainer, rollouter, mq): # Statistics self.current_version = 0 - self._init_weights_info() - self._init_sync_group() - - if self.config.async_training.checkpoint_engine.enable: - self._init_actor_rollout_checkpoint_engine() + replicas = ray.get(rollouter.get_replicas.remote()) + checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine) + self.checkpoint_manager = CheckpointEngineManager( + config=checkpoint_engine_config, trainer=self.actor_wg, replicas=replicas + ) def get_current_param_version(self) -> int: """Get current parameter version number""" @@ -98,22 +100,6 @@ def _init_sync_group(self): group_name=self.sync_group_name, ) - def _init_actor_rollout_checkpoint_engine(self): - ray.get( - self.actor_wg.init_checkpoint_engine( - rank_offset=0, - actor_num=len(self.actor_wg.workers), - rollout_num=len(self.rollout_wg.workers), - ) - ) - ray.get( - self.rollout_wg.init_checkpoint_engine( - rank_offset=len(self.actor_wg.workers), - actor_num=len(self.actor_wg.workers), - rollout_num=len(self.rollout_wg.workers), - ) - ) - def sync_weights(self, version, validate=False, global_steps=0, use_trainer_do_validate=False): """Sync weights between trainer and rollouter, and update parameter version""" start_time = time.time() @@ -130,16 +116,7 @@ def sync_weights(self, version, validate=False, global_steps=0, use_trainer_do_v # sync weights # For sglang, always use sync_rollout_weights instead of sync_rollout_weights_by_checkpoint - # TODO use checkpoint engine for sglang rollout - # rollout_name = getattr(self.config.actor_rollout_ref.rollout, "name", None) - # use_checkpoint_engine = self.config.async_training.checkpoint_engine.enable and rollout_name != "sglang" - # if use_checkpoint_engine: - # self.actor_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name) - # ray.get(self.rollout_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name)) - # else: - # self.actor_wg.sync_rollout_weights(self.sync_group_name) - # ray.get(self.rollout_wg.sync_rollout_weights(self.sync_group_name)) - + self.checkpoint_manager.update_weights() end_time = time.time() print( f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds, " diff --git a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh index e50abb4dd5d..cc936f50dc1 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh @@ -103,7 +103,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ data.val_files="${TEST_FILE}" \ data.prompt_key=prompt \ data.truncation='left' \ - actor_rollout_ref.actor.strategy=fsdp \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp \ critic.strategy=fsdp \ data.max_prompt_length=${max_prompt_length} \ data.max_response_length=${max_response_length} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh index cc1a9bb65c0..2a5eb1bb966 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh @@ -98,7 +98,7 @@ python3 -m verl.experimental.fully_async_policy.fully_async_main \ actor_rollout_ref.actor.use_dynamic_bsz=True \ actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh index b076456bc27..ba8e6804fdb 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh @@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh index 203e4c03bd9..5561208ee6d 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh @@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh index 7f29d44e225..242a5117a5e 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh @@ -92,7 +92,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh index cd2906a0053..ee0657eace7 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh @@ -92,7 +92,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh index 81b50382c32..002c1206b8a 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh @@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh index 7cc55f632aa..f01fb8184e7 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh @@ -96,7 +96,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh index ddb177681be..2b2143ffa21 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh @@ -90,7 +90,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh index f00b41085c5..c04a09d3266 100644 --- a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh +++ b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh @@ -131,7 +131,6 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - async_training.compute_prox_log_prob=True \ algorithm.rollout_correction.rollout_is=${rollout_is} \ algorithm.rollout_correction.rollout_rs=${rollout_rs} \ algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ diff --git a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml index cb2f8c2054c..0e4677be368 100644 --- a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml +++ b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml @@ -6,6 +6,9 @@ defaults: - ppo_megatron_trainer - _self_ +trainer: + use_legacy_worker_impl: disable + # config for the rollout (only for resource isolation) rollout: # Number of nodes used in the rollout @@ -20,9 +23,9 @@ actor_rollout_ref: free_cache_engine: False # Must be enabled! Otherwise, log_probs cannot be calculated. calculate_log_probs: True - # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced. - # TODO: Can be removed in the future once parameter synchronization is ready. - load_format: "auto" + + checkpoint_engine: + backend: "nccl" # Only then will the use of log probs be correct. # And it can be used in conjunction with other rollout_correction algorithms. diff --git a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml index 012745e2aa3..dc784b2ae73 100644 --- a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml +++ b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml @@ -6,6 +6,9 @@ defaults: - ppo_trainer - _self_ +trainer: + use_legacy_worker_impl: disable + # config for the rollout (only for resource isolation) rollout: # Number of nodes used in the rollout @@ -20,9 +23,9 @@ actor_rollout_ref: free_cache_engine: False # Must be enabled! Otherwise, log_probs cannot be calculated. calculate_log_probs: True - # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced. - # TODO: Can be removed in the future once parameter synchronization is ready. - load_format: "auto" + + checkpoint_engine: + backend: "nccl" # Only then will the use of log probs be correct. # And it can be used in conjunction with other rollout_correction algorithms. diff --git a/verl/experimental/one_step_off_policy/fsdp_workers.py b/verl/experimental/one_step_off_policy/fsdp_workers.py deleted file mode 100644 index 67584a198bc..00000000000 --- a/verl/experimental/one_step_off_policy/fsdp_workers.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os - -import torch -import torch.distributed -from omegaconf import DictConfig -from ray.util.collective import collective -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from verl.experimental.one_step_off_policy.distributed_utils import vllm_stateless_init_process_group -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils.device import ( - get_device_name, - get_torch_device, -) -from verl.utils.fsdp_utils import ( - fsdp_version, - load_fsdp_model_to_gpu, - offload_fsdp_model_to_cpu, -) -from verl.utils.ray_utils import get_event_loop -from verl.workers.fsdp_workers import ( - ActorRolloutRefWorker, - AsyncActorRolloutRefWorker, - CriticWorker, -) - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -device_name = get_device_name() - -__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] - - -class DetachSync(AsyncActorRolloutRefWorker): - def _get_actor_params(self): - pass - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size): - rank = torch.distributed.get_rank() + rank_offset - self._weight_sync_group = vllm_stateless_init_process_group( - master_address, - master_port, - rank, - world_size, - get_torch_device().current_device(), - ) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights(self): - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - if self._is_actor and self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - params = self._get_actor_params() if self._is_actor else None - - rollout_name = self.config.rollout.name - if self._is_rollout: - if rollout_name == "vllm": - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - inference_model = self.rollout.inference_engine.worker.model_runner.model - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - inference_model = self.rollout._engine - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - loop = get_event_loop() - for key, shape, dtype in self._weights_info: - tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) - if self._is_actor: - assert key in params - origin_data = params[key] - if hasattr(origin_data, "full_tensor"): - origin_data = origin_data.full_tensor() - if torch.distributed.get_rank() == 0: - tensor.copy_(origin_data) - - if device_name == "npu": - self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream()) - else: - collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") - - if self._is_rollout: - if rollout_name == "vllm": - inference_model.load_weights([(key, tensor)]) - elif rollout_name == "sglang": - # first_rank_in_node = self._tp_rank % tp_size_per_node == 0, - # Only the first rank within each node (i.e., the local rank is 0) initializes the engine; - # engines for other ranks are set to None. - - if inference_model is not None: - loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)])) - - if self._is_actor and self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - get_torch_device().empty_cache() - - async def update_weights(self, inference_engine, params): - from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights - - await sgl_update_weights( - engine=inference_engine, - params_batch=params, - device_mesh_key="infer_tp", - device_mesh=self.rollout_device_mesh, - ) - - if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0: - await inference_engine.flush_cache() - - -class DetachActorWorker(DetachSync): - def _get_actor_params(self): - assert self._is_actor - params = self.actor_module_fsdp.state_dict() - from verl.utils.model import convert_weight_keys - - params = convert_weight_keys( - params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) - ) - return params - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def get_actor_weights_info(self): - assert self._is_actor - if hasattr(self, "_weights_info"): - return self._weights_info - if fsdp_version(self.actor_module_fsdp) == 1: - from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType - - FSDP.set_state_dict_type( - self.actor_module_fsdp, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(), - ) - params = self._get_actor_params() - ret = [] - for key, tensor in params.items(): - ret.append((key, tensor.size(), tensor.dtype)) - self._weights_info = ret - return ret - - -class DetachAsyncRolloutWorker(DetachSync): - def __init__(self, config: DictConfig, role: str): - print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") - ActorRolloutRefWorker.__init__(self, config, role) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def set_actor_weights_info(self, weights_info): - assert self._is_rollout - self._weights_info = weights_info diff --git a/verl/experimental/one_step_off_policy/main_ppo.py b/verl/experimental/one_step_off_policy/main_ppo.py index 6a7405d6c0f..0c6ecaedf0e 100644 --- a/verl/experimental/one_step_off_policy/main_ppo.py +++ b/verl/experimental/one_step_off_policy/main_ppo.py @@ -64,10 +64,6 @@ def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager: assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" - rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes - resource_pool_spec["rollout_pool"] = rollout_pool - mapping[Role.Rollout] = "rollout_pool" - return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) @@ -81,48 +77,15 @@ def create_role_worker_mapping(config): Returns: dict: Mapping from roles to worker classes """ - # Select worker class based on strategy - use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") - if use_legacy_worker_impl == "disable": - from verl.experimental.separation.engine_workers import ( - DetachActorWorker, - DetachAsyncRolloutWorker, - TrainingWorker, - ) - from verl.single_controller.ray import RayWorkerGroup - - ray_worker_group_cls = RayWorkerGroup - - CriticWorker = TrainingWorker - else: - if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.experimental.one_step_off_policy.fsdp_workers import ( - CriticWorker, - DetachActorWorker, - DetachAsyncRolloutWorker, - ) - from verl.single_controller.ray import RayWorkerGroup - - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.critic.strategy == "megatron" - from verl.experimental.one_step_off_policy.megatron_workers import ( - CriticWorker, - DetachActorWorker, - DetachAsyncRolloutWorker, - ) - from verl.single_controller.ray import RayWorkerGroup - - ray_worker_group_cls = RayWorkerGroup - else: - raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}") + from verl.experimental.separation.engine_workers import DetachActorWorker + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.engine_workers import TrainingWorker + + ray_worker_group_cls = RayWorkerGroup role_worker_mapping = { Role.Actor: ray.remote(DetachActorWorker), - Role.Rollout: ray.remote(DetachAsyncRolloutWorker), - Role.Critic: ray.remote(CriticWorker), + Role.Critic: ray.remote(TrainingWorker), } # Add reference policy (if KL loss or reward is required) @@ -219,6 +182,10 @@ def main(config): # Automatically set `config.trainer.device = npu` when running on Ascend NPU. auto_set_device(config) + # TODO: unify rollout config with actor_rollout_ref + config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes + config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node + run_ppo(config, task_runner_class=OneStepTaskRunner) print(f"total time: {time() - start_time:.2f} seconds") diff --git a/verl/experimental/one_step_off_policy/megatron_workers.py b/verl/experimental/one_step_off_policy/megatron_workers.py deleted file mode 100644 index d5c0a18cfea..00000000000 --- a/verl/experimental/one_step_off_policy/megatron_workers.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os - -import torch -import torch.distributed -from omegaconf import DictConfig -from ray.util.collective import collective - -from verl.experimental.one_step_off_policy.distributed_utils import vllm_stateless_init_process_group -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils.device import ( - get_device_name, - get_torch_device, -) -from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu -from verl.utils.ray_utils import get_event_loop -from verl.workers.megatron_workers import ( - ActorRolloutRefWorker, - AsyncActorRolloutRefWorker, - CriticWorker, -) - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -device_name = get_device_name() - -__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] - - -class DetachSync(AsyncActorRolloutRefWorker): - def _get_actor_params(self): - pass - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size): - rank = torch.distributed.get_rank() + rank_offset - self._weight_sync_group = vllm_stateless_init_process_group( - master_address, - master_port, - rank, - world_size, - get_torch_device().current_device(), - ) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights(self): - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - params_generator = self._get_actor_params_generator() if self._is_actor else None - - if self._is_actor and self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module) - - rollout_name = self.config.rollout.name - if self._is_rollout: - if rollout_name == "vllm": - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - inference_model = self.rollout.inference_engine.worker.model_runner.model - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - inference_model = self.rollout._engine - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - - loop = get_event_loop() - for key, shape, dtype in self._weights_info: - if self._is_actor: - weight_key, weight = next(params_generator) - assert key == weight_key - assert shape == weight.size() - assert dtype == weight.dtype - - tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) - if self._is_actor and torch.distributed.get_rank() == 0: - tensor.copy_(weight) - - if device_name == "npu": - self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream()) - else: - collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") - - if self._is_rollout: - if rollout_name == "vllm": - inference_model.load_weights([(key, tensor)]) - elif rollout_name == "sglang": - # first_rank_in_node = self._tp_rank % tp_size_per_node == 0, - # Only the first rank within each node (i.e., the local rank is 0) initializes the engine; - # engines for other ranks are set to None. - - if inference_model is not None: - loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)])) - - if self._is_actor and self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - - async def update_weights(self, inference_engine, params): - from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights - - await sgl_update_weights( - engine=inference_engine, - params_batch=params, - device_mesh_key="infer_tp", - device_mesh=self.rollout_device_mesh, - ) - - if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0: - await inference_engine.flush_cache() - - -class DetachActorWorker(DetachSync): - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def _get_actor_params_generator(self): - assert self._is_actor - from verl.models.mcore import get_mcore_weight_converter - from verl.utils.megatron_utils import per_tensor_generator - - layer_name_mapping = { - "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.", - } - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - generator = per_tensor_generator( - self.actor.actor_module, - self.actor_model_config, - weight_converter, - self.tf_config, - layer_name_mapping, - ) - return generator - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def get_actor_weights_info(self): - assert self._is_actor - if hasattr(self, "_weights_info"): - return self._weights_info - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module) - params_generator = self._get_actor_params_generator() - ret = [] - for key, tensor in params_generator: - ret.append((key, tensor.size(), tensor.dtype)) - - self._weights_info = ret - # Here, we only call this function at the beginning, - # and immediately afterwards we call sync_rollout_weights. - # So we no longer call offload in this. - return ret - - -class DetachAsyncRolloutWorker(DetachSync): - def __init__(self, config: DictConfig, role: str): - print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") - ActorRolloutRefWorker.__init__(self, config, role) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def set_actor_weights_info(self, weights_info): - assert self._is_rollout - self._weights_info = weights_info diff --git a/verl/experimental/one_step_off_policy/ray_trainer.py b/verl/experimental/one_step_off_policy/ray_trainer.py index a0bdc0150b3..144632dead5 100644 --- a/verl/experimental/one_step_off_policy/ray_trainer.py +++ b/verl/experimental/one_step_off_policy/ray_trainer.py @@ -27,7 +27,6 @@ import ray import torch from omegaconf import OmegaConf -from ray.util.collective import collective from torch.utils.data import Dataset, Sampler from tqdm import tqdm @@ -88,6 +87,8 @@ def __init__( self.hybrid_engine = config.actor_rollout_ref.hybrid_engine assert not self.hybrid_engine + # Skip rollout worker mapping and let agentloop create it. + role_worker_mapping.pop(Role.Rollout, None) self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager self.use_reference_policy = need_reference_policy(self.config) @@ -138,14 +139,8 @@ def __init__( self.reward_tensor = None self.reward_extra_infos_dict = {} - def _validate(self): - self.actor_rollout_wg = self.rollout_wg - ret = super()._validate() - self.actor_rollout_wg = self.actor_wg - return ret - def _create_actor_rollout_classes(self): - for role in [Role.Actor, Role.Rollout]: + for role in [Role.Actor]: resource_pool = self.resource_pool_manager.get_resource_pool(role) role_cls = RayClassWithInitArgs( cls=self.role_worker_mapping[role], @@ -169,13 +164,8 @@ def _init_models(self): self.rm_wg.init_model() self.actor_wg = self.all_wg[str(Role.Actor)] - self.rollout_wg = self.all_wg[str(Role.Rollout)] self.actor_wg.init_model() - self.rollout_wg.init_model() self.actor_rollout_wg = self.actor_wg - weights_info = self.actor_wg.get_actor_weights_info()[0] - self.rollout_wg.set_actor_weights_info(weights_info) - self._create_weight_sync_group() def _init_async_rollout_manager(self): # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design @@ -192,47 +182,10 @@ def _init_async_rollout_manager(self): from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager self.async_rollout_mode = True - self.async_rollout_manager = OneStepOffAgentLoopManager( - config=self.config, worker_group=self.rollout_wg, reward_loop_worker_handles=reward_loop_worker_handles + self.async_rollout_manager = OneStepOffAgentLoopManager.create( + config=self.config, reward_loop_worker_handles=reward_loop_worker_handles ) - def _create_weight_sync_group(self): - from verl.utils.device import get_nccl_backend - - actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers - n_workers = len(actor_rollout_workers) - - if self.device_name == "npu": - master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote()).strip("[]") - master_port = ray.get(self.actor_wg.workers[0]._get_free_port.remote()) - self.actor_wg.create_weight_sync_group( - master_address, - master_port, - 0, - n_workers, - ) - ray.get( - self.rollout_wg.create_weight_sync_group( - master_address, - master_port, - len(self.actor_wg.workers), - n_workers, - ) - ) - else: - # Create Ray collective group for fallback communication - collective.create_collective_group( - actor_rollout_workers, - n_workers, - list(range(0, n_workers)), - backend=get_nccl_backend(), - group_name="actor_rollout", - ) - - def sync_rollout_weights(self): - self.actor_wg.sync_rollout_weights() - ray.get(self.rollout_wg.sync_rollout_weights()) - def _create_continuous_iterator(self): """ Create a continuous data iterator across epoch @@ -434,6 +387,7 @@ async def _fit_generate(self, batch_data_future, continuous_iterator): with marked_timer("gen", timing_raw, color="red"): _metrics, _timing_raw, epoch, batch, future_reward = await batch_data_future + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature timing_raw.update(batch.meta_info["timing"]) timing_raw.update(_timing_raw) metrics.update(_metrics) @@ -452,8 +406,3 @@ async def _fit_generate(self, batch_data_future, continuous_iterator): batch_data_future = None return batch, batch_data_future - - def _fit_update_weights(self): - # TODO: use checkpoint engine to update weight - # self.sync_rollout_weights() - pass diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh index 7101009e07b..cbefe87424b 100644 --- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh +++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_4_12.sh @@ -73,7 +73,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh index 2cd5ae46d9c..c35513cf9f2 100644 --- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh +++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64.sh @@ -75,7 +75,7 @@ python -m verl.experimental.one_step_off_policy.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh index 8775bdfa0f5..10ce9122269 100644 --- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh +++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_64_64_ris.sh @@ -85,7 +85,7 @@ python -m verl.experimental.one_step_off_policy.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh index b44fd6b25e6..a5c6ee87143 100644 --- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh +++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_colocate.sh @@ -68,7 +68,7 @@ python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh index e56d2f90dd5..2725bb5bc3d 100644 --- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh +++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_4_12.sh @@ -73,7 +73,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh index 3c18460f1e2..5ccceec5f9d 100644 --- a/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh +++ b/verl/experimental/one_step_off_policy/shell/dapo_7b_math_fsdp2_sglang_colocate.sh @@ -68,7 +68,7 @@ python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ diff --git a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh index b2dfa578ed7..facabdf58e8 100644 --- a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh +++ b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6.sh @@ -26,7 +26,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ diff --git a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh index 1f5f72e6bcc..5c959f49961 100644 --- a/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh +++ b/verl/experimental/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh @@ -26,7 +26,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ diff --git a/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh b/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh index b94a66f588b..c5c5eb11d2a 100644 --- a/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh +++ b/verl/experimental/one_step_off_policy/shell/grpo_3b_gsm8k_fsdp2_2_6.sh @@ -25,7 +25,7 @@ python3 -m verl.experimental.one_step_off_policy.main_ppo \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ - actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \ critic.strategy=fsdp2 \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ diff --git a/verl/experimental/reward_loop/reward_loop.py b/verl/experimental/reward_loop/reward_loop.py index 9261db23719..712cd9bde7a 100644 --- a/verl/experimental/reward_loop/reward_loop.py +++ b/verl/experimental/reward_loop/reward_loop.py @@ -38,7 +38,6 @@ def migrate_legacy_reward_impl(config): """ Migrate the legacy reward model implementation to the new one. - This is a temporary fix. A more robust one will be added. """ # 1. reward workers migration # config.reward_model.num_workers -> config.reward.num_workers @@ -49,7 +48,7 @@ def migrate_legacy_reward_impl(config): # config.reward_model.reward_manager -> config.reward.reward_manager if config.reward_model.reward_manager is not None: config.reward.reward_manager.name = config.reward_model.reward_manager - if config.reward_model.get("reward_loop_source") is not None: + if config.reward_model.reward_loop_source is not None: config.reward.reward_manager.source = config.reward_model.reward_loop_source config.reward.reward_manager.module.path = config.reward_model.reward_loop_module_path config.reward.reward_manager.module.name = config.reward_model.reward_loop_class_name @@ -64,19 +63,29 @@ def migrate_legacy_reward_impl(config): for key in ["enable", "enable_resource_pool", "n_gpus_per_node", "nnodes"]: if config.reward_model.get(key) is not None: config.reward.reward_model[key] = config.reward_model[key] - # for dapo reward kwargs + if config.reward_model.model.path is not None: + config.reward.reward_model.model_path = config.reward_model.model.path + # config.reward_model.reward_kwargs -> config.reward.reward_kwargs (for dapo algo) if config.reward_model.get("reward_kwargs") is not None: - with open_dict(config.reward.reward_model): - config.reward.reward_model["reward_kwargs"] = config.reward_model["reward_kwargs"] + with open_dict(config.reward): + config.reward["reward_kwargs"] = config.reward_model["reward_kwargs"] + # config.reward_model.rollout -> config.reward.reward_model.rollout legacy_rollout = config.reward_model.rollout - if not all(v is None for v in legacy_rollout.values()): - config.reward.reward_model.rollout = legacy_rollout + for key in legacy_rollout.keys(): + if legacy_rollout[key] is not None: + config.reward.reward_model.rollout[key] = legacy_rollout[key] # 5. sandbox_fusion migration # config.sandbox_fusion -> reward.sandbox_fusion if not all(v is None for v in config.sandbox_fusion.values()): config.reward.sandbox_fusion = config.sandbox_fusion + # 6. delete legacy config from configs + with open_dict(config): + del config.reward_model + del config.custom_reward_function + del config.sandbox_fusion + return config @@ -222,12 +231,10 @@ async def compute_score_disrm(self, data: DataProto) -> dict: engine_name = self.config.reward.reward_model.rollout.name model_name = self.config.reward.reward_model.model_path if engine_name == "vllm": - # TODO (dyy): the "activation" has been changed to "use_activation" in vllm 0.11.2 payloads = { "model": model_name, "input": disrm_prompt, - "activation": False, - # "add_special_tokens": False, # vllm >= 0.11.2 + "use_activation": False, } output = await self._post_request(payloads, "classify") rm_score = output["data"][-1]["probs"][-1] diff --git a/verl/experimental/separation/engine_workers.py b/verl/experimental/separation/engine_workers.py index 5e98052bb79..0f8062ff888 100644 --- a/verl/experimental/separation/engine_workers.py +++ b/verl/experimental/separation/engine_workers.py @@ -15,290 +15,58 @@ import logging import os -import time -import torch -import torch.distributed from omegaconf import DictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from verl.experimental.fully_async_policy.base_detach_sync import BaseDetachNcclSync from verl.single_controller.base.decorator import Dispatch, register from verl.utils.device import ( get_device_name, - get_torch_device, ) -from verl.utils.fsdp_utils import fsdp_version -from verl.utils.megatron_utils import per_tensor_generator -from verl.workers.engine_workers import ActorRolloutRefWorker, TrainingWorker +from verl.workers.engine_workers import ActorRolloutRefWorker logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) device_name = get_device_name() -__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "TrainingWorker"] +__all__ = ["DetachActorWorker"] -class DetachNcclSync(BaseDetachNcclSync, ActorRolloutRefWorker): - def __init__(self, config: DictConfig, role: str): - BaseDetachNcclSync.__init__(self, config, role) - ActorRolloutRefWorker.__init__(self, config, role) - - def _get_actor_params(self): - pass - - def load_model_to_gpu(self): - if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - from verl.utils.fsdp_utils import load_fsdp_model_to_gpu - - load_fsdp_model_to_gpu(self.actor_module_fsdp) - elif self.config.actor_rollout_ref.actor.strategy == "megatron": - from verl.utils.megatron_utils import load_megatron_model_to_gpu - - load_megatron_model_to_gpu(self.actor_module, False) - else: - raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}") - - def offload_model_to_cpu(self): - if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu - - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - elif self.config.actor_rollout_ref.actor.strategy == "megatron": - from verl.utils.megatron_utils import offload_megatron_model_to_cpu - - offload_megatron_model_to_cpu(self.actor_module) - else: - raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}") - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights(self, sync_group_name="actor_rollout"): - # TODO: Refator this function for the chekpoint engine - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - if self._is_actor and self.engine._is_offload_param: - self.load_model_to_gpu() - - if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - params = self._get_actor_params() if self._is_actor else None - - elif self.config.actor_rollout_ref.actor.strategy == "megatron": - params_generator = self._get_actor_params_generator() if self._is_actor else None - params = {key: tensor for key, tensor in params_generator} if params_generator is not None else None - - else: - raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}") - - rollout_name = self.config.rollout.name - - inference_model = None - if self._is_rollout and (not self._is_actor): - if rollout_name == "vllm": - inference_model = BaseDetachNcclSync.get_inference_model(self.rollout) - - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - inference_model = self.rollout._engine - # For ServerAdapter, _engine might be None and needs async initialization - if inference_model is None: - # Initialize the server adapter engine - print("[sync_rollout_weights] Initialize server adapter engine") - - async def init_engine(): - if hasattr(self.rollout, "_init_server_adapter"): - await self.rollout._init_server_adapter() - else: - print("[sync_rollout_weights] No _init_server_adapter method found") - return self.rollout._engine - - inference_model = self._run_async_safely(init_engine()) - if inference_model is None: - raise RuntimeError( - f"Failed to initialize rollout engine. " - f"rollout type: {type(self.rollout)}, " - f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}" - ) - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - - if rollout_name == "sglang" and self._is_rollout: - self._sync_sglang_weights(inference_model, params, sync_group_name) - else: - self._sync_vllm_weights(inference_model, params, sync_group_name) - - if self._is_actor and self.engine._is_offload_param: - self.offload_model_to_cpu() - get_torch_device().empty_cache() - - def cache_actor_weights_to_cpu(self): - # TODO: Refator this function for the chekpoint engine - self.cpu_named_params = {} - local_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if self._is_actor: - if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - params = self._get_actor_params() - for tensor_idx, (key, _, _) in enumerate(self._weights_info): - origin_data = params[key] - if hasattr(origin_data, "full_tensor"): - origin_data = origin_data.full_tensor() - - if tensor_idx % world_size == local_rank: - self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True) - - elif self.config.actor_rollout_ref.actor.strategy == "megatron": - params_generator = self._get_actor_params_generator() - print(f"cache_actor_weights_to_cpu, local_rank:{local_rank}, world_size:{world_size}") - for tensor_idx, (key, tensor) in enumerate(params_generator): - if tensor_idx % world_size == local_rank: - self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True) - else: - raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}") - - get_torch_device().synchronize() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): - # TODO: Refator this function for the chekpoint engine - assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine - assert hasattr(self, "_weights_info") and self._weights_info is not None - - # Load model to GPU - load_start_time = time.time() - if self._is_actor and self.engine._is_offload_param: - self.load_model_to_gpu() - load_duration = time.time() - load_start_time - - from ray.util.collective import collective - - # Cache actor weights to CPU and measure the time taken - cache_start_time = time.time() - self.cache_actor_weights_to_cpu() - cache_end_time = time.time() - cache_duration = cache_end_time - cache_start_time - - # Register the cached weights into the checkpoint engine - self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params) - register_end_time = time.time() - register_duration = register_end_time - cache_end_time - self.cpu_named_params = {} - - collective.barrier(group_name=sync_group_name) - update_start_time = time.time() - - rollout_name = self.config.rollout.name - inference_model = None - if self._is_rollout and (not self._is_actor): - if rollout_name == "vllm": - inference_model = BaseDetachNcclSync.get_inference_model(self.rollout) - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(inference_model) - elif rollout_name == "sglang": - if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - raise NotImplementedError( - "Fully async sglang backend does not support " - f"actor strategy: {self.config.actor_rollout_ref.actor.strategy}" - ) +class DetachActorWorker(ActorRolloutRefWorker): + """ + A worker class that extends ActorRolloutRefWorker to support detaching and restoring the actor model. - inference_model = self.rollout._engine - # For ServerAdapter, _engine might be None and needs async initialization - if inference_model is None: - # Initialize the server adapter engine - print("[sync_rollout_weights] Initialize server adapter engine") + This worker facilitates saving the model state to CPU and restoring it, enabling efficient + resource management and checkpointing in distributed training. It currently supports + FSDP, FSDP2, and Megatron strategies. + """ - async def init_engine(): - if hasattr(self.rollout, "_init_server_adapter"): - await self.rollout._init_server_adapter() - else: - print("[sync_rollout_weights] No _init_server_adapter method found") - return self.rollout._engine - - inference_model = self._run_async_safely(init_engine()) - if inference_model is None: - raise RuntimeError( - f"Failed to initialize rollout engine. " - f"rollout type: {type(self.rollout)}, " - f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}" - ) - else: - raise NotImplementedError(f"Unknown rollout name: {rollout_name}") - - # Update the checkpoint with the inference model and broadcast weights - self.checkpoint_engine.update_checkpoint( - inference_model=inference_model, - group_name=sync_group_name, - overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume, - ) - - update_end_time = time.time() - update_duration = update_end_time - update_start_time - - if self.config.actor_rollout_ref.actor.strategy == "megatron": - collective.barrier(group_name=sync_group_name) - - offload_start_time = time.time() - if self._is_actor and self.engine._is_offload_param: - self.offload_model_to_cpu() - offload_duration = time.time() - offload_start_time - - print( - f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()}," - f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout}," - f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, " - f" register cost {register_duration} seconds, update cost {update_duration} seconds" - ) - - if self._is_actor and self.engine._is_offload_param: - print( - f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds," - f" offload model to cpu cost {offload_duration} seconds" - ) - - -class DetachActorWorker(DetachNcclSync): def __init__(self, config: DictConfig, role: str): - print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...") - DetachNcclSync.__init__(self, config, role) + """ + Initialize the DetachActorWorker. + + Args: + config: Configuration dictionary. + role: The role of the worker (e.g., 'actor', 'rollout', 'ref'). + """ + ActorRolloutRefWorker.__init__(self, config, role) self._strategy_handlers = None self.copy_handler, self.restore_handler = self._get_strategy_handlers() - def _get_actor_params(self): - # TODO: Refator this function for the chekpoint engine - assert self._is_actor - params = self.actor_module_fsdp.state_dict() - from verl.utils.model import convert_weight_keys - - params = convert_weight_keys( - params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) - ) - return params - - def _get_actor_params_generator(self): - # TODO: Refator this function for the chekpoint engine - assert self._is_actor - if self.bridge is not None: - generator = self.bridge.export_weights(self.actor.actor_module) - else: - generator = per_tensor_generator( - self.actor.actor_module, - self.actor_model_config, - self.weight_converter, - self.tf_config, - self.layer_name_mapping, - ) + def _get_strategy_handlers(self): + """ + Get the strategy-specific handlers for saving and restoring the model. - return generator + Returns: + tuple: A tuple containing (save_handler, restore_handler). - def _get_strategy_handlers(self): + Raises: + NotImplementedError: If the strategy is not supported. + """ if self._strategy_handlers is not None: return self._strategy_handlers - strategy = self.config.actor_rollout_ref.actor.strategy + strategy = self.config.actor.strategy if strategy in ["fsdp", "fsdp2"]: from verl.experimental.fully_async_policy.fsdp2_utils import ( @@ -319,47 +87,14 @@ def _get_strategy_handlers(self): return self._strategy_handlers - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def get_actor_weights_info(self): - # TODO: Refator this function for the chekpoint engine - assert self._is_actor - if hasattr(self, "_weights_info"): - return self._weights_info - - ret = [] - - if self.config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - if fsdp_version(self.actor_module_fsdp) == 1: - from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType - - FSDP.set_state_dict_type( - self.actor_module_fsdp, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(), - ) - params = self._get_actor_params() - - for key, tensor in params.items(): - ret.append((key, tensor.size(), tensor.dtype)) - - elif self.config.actor_rollout_ref.actor.strategy == "megatron": - if self.engine._is_offload_param: - self.load_model_to_gpu() - - params_generator = self._get_actor_params_generator() - - for key, tensor in params_generator: - ret.append((key, tensor.size(), tensor.dtype)) - - else: - raise NotImplementedError(f"Unsupported strategy: {self.config.actor_rollout_ref.actor.strategy}") - - self._weights_info = ret - - return ret - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_model_to_cpu(self, n): + """ + Save the current model state to CPU memory. + + Args: + n: Identifier/Key for the saved model state. + """ if not hasattr(self, "cpu_saved_models"): self.cpu_saved_models = {} @@ -367,8 +102,14 @@ def save_model_to_cpu(self, n): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def restore_model_from_cpu(self, n): + """ + Restore the model state from CPU memory. + + Args: + n: Identifier/Key for the saved model state to restore. + """ if n in self.cpu_saved_models: - strategy = self.config.actor_rollout_ref.actor.strategy + strategy = self.config.actor.strategy if strategy in ["fsdp", "fsdp2"]: cpu_sharded_state, global_spec = self.cpu_saved_models[n] @@ -378,16 +119,11 @@ def restore_model_from_cpu(self, n): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def clear_cpu_model(self, n): + """ + Clear the saved model state from CPU memory. + + Args: + n: Identifier/Key for the saved model state to remove. + """ if n in self.cpu_saved_models: del self.cpu_saved_models[n] - - -class DetachAsyncRolloutWorker(DetachNcclSync): - def __init__(self, config: DictConfig, role: str): - print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") - DetachNcclSync.__init__(self, config, role) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def set_actor_weights_info(self, weights_info): - assert self._is_rollout - self._weights_info = weights_info diff --git a/verl/experimental/separation/ray_trainer.py b/verl/experimental/separation/ray_trainer.py index 56945445f6e..ca850b0590d 100644 --- a/verl/experimental/separation/ray_trainer.py +++ b/verl/experimental/separation/ray_trainer.py @@ -103,6 +103,7 @@ def __init__( # reward message self.reward_tensor = None self.reward_extra_infos_dict = {} + self.checkpoint_manager = None def init_workers(self): """Initialize distributed training workers using Ray backend. @@ -119,7 +120,7 @@ def init_workers(self): self._init_async_rollout_manager() self.checkpoint_manager = CheckpointEngineManager( - backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + config=omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine), trainer=self.actor_rollout_wg, replicas=self.async_rollout_manager.rollout_replicas, ) @@ -404,15 +405,12 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto: gen_batch_output = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) with marked_timer("gen", timing_raw, color="red"): - if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) - else: - if self.curr_step_profile: - self.async_rollout_manager.start_profile(global_step=self.global_steps) - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) - self.checkpoint_manager.sleep_replicas() - if self.curr_step_profile: - self.async_rollout_manager.stop_profile() + if self.curr_step_profile: + self.async_rollout_manager.start_profile(global_step=self.global_steps) + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() + if self.curr_step_profile: + self.async_rollout_manager.stop_profile() timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -421,15 +419,12 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto: with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - if not self.async_rollout_mode: - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - else: - if self.curr_step_profile: - self.async_rollout_manager.start_profile() - gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) - self.checkpoint_manager.sleep_replicas() - if self.curr_step_profile: - self.async_rollout_manager.stop_profile() + if self.curr_step_profile: + self.async_rollout_manager.start_profile() + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + self.checkpoint_manager.sleep_replicas() + if self.curr_step_profile: + self.async_rollout_manager.stop_profile() batch = batch.union(gen_baseline_output) # compute reward model score on batch rm_scores = None diff --git a/verl/experimental/transfer_queue/agent_loop.py b/verl/experimental/transfer_queue/agent_loop.py deleted file mode 100644 index 308be43ebc0..00000000000 --- a/verl/experimental/transfer_queue/agent_loop.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio - -import numpy as np -import ray -from transfer_queue import BatchMeta - -import verl.experimental.agent_loop.agent_loop as agent_loop - - -class AgentLoopManager(agent_loop.AgentLoopManager): - def generate_sequences(self, prompts: BatchMeta) -> BatchMeta: - """Split input batch and dispatch to agent loop workers. - - Args: - prompts (BatchMeta): Input batch. - - Returns: - BatchMeta: Output batch metadata. - """ - - chunkes = prompts.chunk(len(self.agent_loop_workers)) - outputs = ray.get( - [ - worker.generate_sequences.remote(chunk) - for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) - ] - ) - output = BatchMeta.concat(outputs) - - # calculate performance metrics - metrics = [output.extra_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]] - timing = self._performance_metrics(metrics, output) - - output.set_extra_info("timing", timing) - return output - - def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: BatchMeta) -> dict[str, float]: - timing = {} - t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk]) - t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk]) - timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min() - timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max() - timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean() - timing["agent_loop/tool_calls/min"] = t_tool_calls.min() - timing["agent_loop/tool_calls/max"] = t_tool_calls.max() - timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean() - - # TODO (TQ): initialize tq during init when enable TQ switch is stable - tq_client = self._create_transferqueue_client() - # batch sequence generation is bounded by the slowest sample - slowest = np.argmax(t_generate_sequences + t_tool_calls) - attention_mask = asyncio.run(tq_client.async_get_data(output[slowest]))["attention_mask"] - prompt_length = output.samples[0].fields["prompts"].shape[0] - timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest] - timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest] - timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() - timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() - - return timing - - def create_transferqueue_client_for_workers(self): - # TODO (TQ): initialize tq during worker init when enable TQ switch is stable - ray.get([worker.create_transferqueue_client.remote() for worker in self.agent_loop_workers]) - - def _create_transferqueue_client(self): - """Create a client for data system (TransferQueue).""" - from verl.single_controller.ray.base import get_random_string - from verl.utils.transferqueue_utils import create_transferqueue_client - - client_name = get_random_string(length=6) - - tq_client = create_transferqueue_client( - client_id=f"AgentLoopManager_{client_name}", - config=self.config.transfer_queue, - ) - - return tq_client diff --git a/verl/experimental/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml b/verl/experimental/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml deleted file mode 100644 index 37b19b45708..00000000000 --- a/verl/experimental/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml +++ /dev/null @@ -1,14 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_megatron_trainer - - _self_ - -# config for TransferQueue -transfer_queue: - enable: True - num_global_batch: 1 - storage_backend: AsyncSimpleStorageManager - num_data_storage_units: 8 \ No newline at end of file diff --git a/verl/experimental/transfer_queue/config/transfer_queue_ppo_trainer.yaml b/verl/experimental/transfer_queue/config/transfer_queue_ppo_trainer.yaml deleted file mode 100644 index 7a5f57ddd4f..00000000000 --- a/verl/experimental/transfer_queue/config/transfer_queue_ppo_trainer.yaml +++ /dev/null @@ -1,14 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -# config for TransferQueue -transfer_queue: - enable: True - num_global_batch: 1 - storage_backend: AsyncSimpleStorageManager - num_data_storage_units: 8 diff --git a/verl/experimental/transfer_queue/main_ppo.py b/verl/experimental/transfer_queue/main_ppo.py deleted file mode 100644 index 075a8f7ace6..00000000000 --- a/verl/experimental/transfer_queue/main_ppo.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" - -import os -import socket - -import hydra -import ray -from omegaconf import OmegaConf - -from verl.trainer.constants_ppo import get_ppo_ray_runtime_env -from verl.trainer.main_ppo import TaskRunner as MainTaskRunner -from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler -from verl.trainer.ppo.utils import need_critic, need_reference_policy -from verl.utils.config import validate_config -from verl.utils.device import auto_set_device, is_cuda_available - -from .ray_trainer import RayPPOTrainer - - -@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) -def main(config): - """Main entry point for PPO training with Hydra configuration management. - - Args: - config_dict: Hydra configuration dictionary containing training parameters. - """ - # Automatically set `config.trainer.device = npu` when running on Ascend NPU. - auto_set_device(config) - - run_ppo(config) - - -# Define a function to run the PPO-like training process -def run_ppo(config, task_runner_class=None) -> None: - """Initialize Ray cluster and run distributed PPO training process. - - Args: - config: Training configuration object containing all necessary parameters - for distributed PPO training including Ray initialization settings, - model paths, and training hyperparameters. - task_runner_class: For recipe to change TaskRunner. - """ - # Check if Ray is not initialized - if not ray.is_initialized(): - # Initialize Ray with a local cluster configuration - # Set environment variables in the runtime environment to control tokenizer parallelism, - # NCCL debug level, VLLM logging level, and allow runtime LoRA updating - # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration - default_runtime_env = get_ppo_ray_runtime_env() - ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) - runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) - - if config.transfer_queue.enable: - # Add runtime environment variables for transfer queue - runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) - runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" - runtime_env_kwargs["env_vars"] = runtime_env_vars - - runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) - ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) - print(f"ray init kwargs: {ray_init_kwargs}") - ray.init(**OmegaConf.to_container(ray_init_kwargs)) - - if task_runner_class is None: - task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head - - # Create a remote instance of the TaskRunner class, and - # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete - if ( - is_cuda_available - and config.global_profiler.tool == "nsys" - and config.global_profiler.get("steps") is not None - and len(config.global_profiler.get("steps", [])) > 0 - ): - from verl.utils.import_utils import is_nvtx_available - - assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" - nsight_options = OmegaConf.to_container( - config.global_profiler.global_tool_config.nsys.controller_nsight_options - ) - runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() - else: - runner = task_runner_class.remote() - ray.get(runner.run.remote(config)) - - # [Optional] get the path of the timeline trace file from the configuration, default to None - # This file is used for performance analysis - timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) - if timeline_json_file: - ray.timeline(filename=timeline_json_file) - - -class TaskRunner(MainTaskRunner): - def run(self, config): - """Execute the main PPO training workflow. - - This method sets up the distributed training environment, initializes - workers, datasets, and reward functions, then starts the training process. - - Args: - config: Training configuration object containing all parameters needed - for setting up and running the PPO training process. - """ - # Print the initial configuration. `resolve=True` will evaluate symbolic values. - from pprint import pprint - - from verl.utils.fs import copy_to_local - - print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") - pprint(OmegaConf.to_container(config, resolve=True)) - OmegaConf.resolve(config) - - actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) - self.add_critic_worker(config) - - # We should adopt a multi-source reward function here: - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # finally, we combine all the rewards together - # The reward type depends on the tag of the data - self.add_reward_model_resource_pool(config) - - # Add a reference policy worker if KL loss or KL reward is used. - self.add_ref_policy_worker(config, actor_rollout_cls) - - # validate config - validate_config( - config=config, - use_reference_policy=need_reference_policy(config), - use_critic=need_critic(config), - ) - - # Download the checkpoint from HDFS to the local machine. - # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on - local_path = copy_to_local( - config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) - ) - - # Instantiate the tokenizer and processor. - from verl.utils import hf_processor, hf_tokenizer - - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - # Used for multimodal LLM, could be None - processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) - - resource_pool_manager = self.init_resource_pool_mgr(config) - - from verl.utils.dataset.rl_dataset import collate_fn - - # Create training and validation datasets. - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) - train_sampler = create_rl_sampler(config.data, train_dataset) - - # Initialize the PPO trainer. - trainer = RayPPOTrainer( - config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=self.role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - train_dataset=train_dataset, - val_dataset=val_dataset, - collate_fn=collate_fn, - train_sampler=train_sampler, - ) - # Initialize the workers of the trainer. - trainer.init_workers() - # Start the training process. - trainer.fit() - - -if __name__ == "__main__": - main() diff --git a/verl/experimental/transfer_queue/ray_trainer.py b/verl/experimental/transfer_queue/ray_trainer.py deleted file mode 100644 index 3d7c0f58d30..00000000000 --- a/verl/experimental/transfer_queue/ray_trainer.py +++ /dev/null @@ -1,1614 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import json -import logging -import math -import os -import uuid -from collections import defaultdict -from pprint import pprint -from typing import Any, Optional - -import numpy as np -import tensordict -import torch -from omegaconf import OmegaConf, open_dict -from packaging.version import parse as parse_version -from tensordict import TensorDict -from torch.utils.data import Dataset, Sampler -from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm import tqdm -from transfer_queue import ( - BatchMeta, - SimpleStorageUnit, - TransferQueueController, - get_placement_group, - process_zmq_server_info, -) - -from verl import DataProto -from verl.checkpoint_engine import CheckpointEngineManager -from verl.experimental.dataset.sampler import AbstractCurriculumSampler -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.config import AlgoConfig -from verl.trainer.ppo import core_algos -from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - process_validation_metrics, -) -from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi -from verl.utils.config import omega_conf_to_dataclass -from verl.utils.debug import marked_timer -from verl.utils.metric import reduce_metrics -from verl.utils.rollout_skip import RolloutSkip -from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance -from verl.utils.torch_functional import masked_mean -from verl.utils.tracking import ValidationGenerationsLogger -from verl.utils.transferqueue_utils import create_transferqueue_client, get_transferqueue_client, tqbridge - - -@tqbridge(put_data=False) -def compute_reward_decorated(data): - reward_tensor = data.batch["rm_scores"] - reward_extra_keys = data.meta_info.get("reward_extra_keys", []) - reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys} - return reward_tensor, reward_extra_info - - -@tqbridge(put_data=False) -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): - """Apply KL penalty to the token-level rewards. - - This function computes the KL divergence between the reference policy and current policy, - then applies a penalty to the token-level rewards based on this divergence. - - Args: - data (DataProto): The data containing batched model outputs and inputs. - kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. - kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". - - Returns: - tuple: A tuple containing: - - The updated data with token-level rewards adjusted by KL penalty - - A dictionary of metrics related to the KL penalty - """ - response_mask = data.batch["response_mask"] - token_level_scores = data.batch["token_level_scores"] - batch_size = data.batch.batch_size[0] - - # compute kl between ref_policy and current policy - # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. - kld = core_algos.kl_penalty( - data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty - ) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - - token_level_rewards = token_level_scores - beta * kld - - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - - metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} - - return token_level_rewards, metrics - - -def compute_response_mask(batch_meta: BatchMeta, tq_client): - """Compute the attention mask for the response part of the sequence. - - This function extracts the portion of the attention mask that corresponds to the model's response, - which is used for masking computations that should only apply to response tokens. - - Args: - batch_meta (BatchMeta): The data containing batched model outputs and inputs. - - Returns: - BatchMeta: The BatchMeta of attention mask for the response tokens. - """ - data = tq_client.get_data(batch_meta) - - responses = data["responses"] - response_length = responses.size(1) - attention_mask = data["attention_mask"] - response_mask = attention_mask[:, -response_length:] - output = TensorDict({"response_mask": response_mask}, batch_size=response_mask.size(0)) - - batch_meta = tq_client.put(data=output, metadata=batch_meta) - - return batch_meta - - -@tqbridge(put_data=False) -def compute_advantage( - data: DataProto, - adv_estimator: AdvantageEstimator, - gamma: float = 1.0, - lam: float = 1.0, - num_repeat: int = 1, - norm_adv_by_std_in_grpo: bool = True, - config: Optional[AlgoConfig] = None, -) -> tuple[Any, Any]: - """Compute advantage estimates for policy optimization. - - This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. - The advantage estimates are used to guide policy optimization in RL algorithms. - - Args: - data (DataProto): The data containing batched model outputs and inputs. - adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). - gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. - lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. - num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. - norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in - GRPO. Defaults to True. - config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. - - Returns: - tuple: A tuple containing: - - advantages: The computed advantage estimates. - - returns: The computed returns. - """ - # prepare response group - if adv_estimator == AdvantageEstimator.GAE: - # Compute advantages and returns using Generalized Advantage Estimation (GAE) - advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=data.batch["token_level_rewards"], - values=data.batch["values"], - response_mask=data.batch["response_mask"], - gamma=gamma, - lam=lam, - ) - # TODO (TQ): adapt core_algos.compute_pf_ppo_reweight_data function to support transfer queue - if config.get("use_pf_ppo", False): - data = core_algos.compute_pf_ppo_reweight_data( - data, - config.pf_ppo.get("reweight_method"), - config.pf_ppo.get("weight_pow"), - ) - elif adv_estimator == AdvantageEstimator.GRPO: - # Initialize the mask for GRPO calculation - grpo_calculation_mask = data.batch["response_mask"] - # Call compute_grpo_outcome_advantage with parameters matching its definition - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=grpo_calculation_mask, - index=data.non_tensor_batch["uid"], - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - else: - # handle all other adv estimator type other than GAE and GRPO - adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) - adv_kwargs = { - "token_level_rewards": data.batch["token_level_rewards"], - "response_mask": data.batch["response_mask"], - "config": config, - } - if "uid" in data.non_tensor_batch: # optional - adv_kwargs["index"] = data.non_tensor_batch["uid"] - if "reward_baselines" in data.batch: # optional - adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] - - # calculate advantage estimator - advantages, returns = adv_estimator_fn(**adv_kwargs) - return advantages, returns - - -@tqbridge(put_data=False) -def compute_data_metrics_decorated(batch, use_critic: bool = True): - return compute_data_metrics(batch, use_critic) - - -@tqbridge(put_data=False) -def compute_timing_metrics_decorated(batch, timing_raw: dict[str, float]) -> dict[str, Any]: - return compute_timing_metrics(batch, timing_raw) - - -@tqbridge(put_data=False) -def compute_throughout_metrics_decorated(batch, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: - return compute_throughout_metrics(batch, timing_raw, n_gpus) - - -@tqbridge(put_data=False) -def calculate_debug_metrics_decorated(data): - from verl.utils.debug.metrics import calculate_debug_metrics - - return calculate_debug_metrics(data) - - -class RayPPOTrainer: - """Distributed PPO trainer using Ray for scalable reinforcement learning. - - This trainer orchestrates distributed PPO training across multiple nodes and GPUs, - managing actor rollouts, critic training, and reward computation with Ray backend. - Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, - processor=None, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - collate_fn=None, - train_sampler: Optional[Sampler] = None, - device_name=None, - ): - """ - Initialize distributed PPO trainer with Ray backend. - Note that this trainer runs on the driver process on a single CPU/GPU node. - - Args: - config: Configuration object containing training parameters. - tokenizer: Tokenizer used for encoding and decoding text. - role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. - resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. - ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. - processor: Optional data processor, used for multimodal data - train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. - val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. - collate_fn: Function to collate data samples into batches. - train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. - device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. - """ - - # Store the tokenizer for text processing - self.tokenizer = tokenizer - self.processor = processor - self.config = config - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = need_reference_policy(self.config) - self.use_rm = need_reward_model(self.config) - self.use_critic = need_critic(self.config) - self.ray_worker_group_cls = ray_worker_group_cls - self.device_name = device_name if device_name else self.config.trainer.device - self.validation_generations_logger = ValidationGenerationsLogger( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - ) - - lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) - if lora_rank <= 0: - lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) - # if ref_in_actor is True, the reference policy will be actor without lora applied - self.ref_in_actor = lora_rank > 0 - - # define in-reward KL control - # kl loss control currently not suppoorted - if self.config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) - - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - - self.tq_client = self._initialize_transferqueue() - - def _initialize_transferqueue(self): - # 1. initialize TransferQueueStorage - if self.config.transfer_queue.storage_backend == "AsyncSimpleStorageManager": - train_data_size = ( - self.config.data.train_batch_size - * self.config.transfer_queue.num_global_batch - * self.config.actor_rollout_ref.rollout.n - ) - val_data_size = self.val_dataset_size * self.config.actor_rollout_ref.rollout.val_kwargs.n - - total_storage_size = train_data_size + val_data_size - self.data_system_storage_units = {} - storage_placement_group = get_placement_group( - self.config.transfer_queue.num_data_storage_units, num_cpus_per_actor=1 - ) - for storage_unit_rank in range(self.config.transfer_queue.num_data_storage_units): - storage_node = SimpleStorageUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote( - storage_unit_size=math.ceil(total_storage_size / self.config.transfer_queue.num_data_storage_units) - ) - self.data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") - else: - raise NotImplementedError("Currently only support AsyncSimpleStorageManager backend in TransferQueue") - - # 2. Initialize TransferQueueController (single controller only) - - # Sampler usage instructions: - # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler: - # Option 1: Pass sampler class (will be instantiated automatically) - # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) - - # Option 2: Pass sampler instance (if you need custom configuration) - # grpo_sampler = GRPOGroupNSampler() - # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) - - # Then use sampling_config in get_meta calls: - # sampling_config={"n_samples_per_prompt": 4} - self.data_system_controller = TransferQueueController.remote() - logging.info("TransferQueueController has been created.") - - # 3. register controller & storage and prepare necessary information - self.data_system_controller_info = process_zmq_server_info(self.data_system_controller) - if self.config.transfer_queue.storage_backend == "AsyncSimpleStorageManager": - self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) - - # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances - # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts, - # breaking the transfer queue client initialization. - tq_config = OmegaConf.create({"transfer_queue": {}}, flags={"allow_objects": True}) - tq_config.transfer_queue.controller_info = self.data_system_controller_info - - if self.config.transfer_queue.storage_backend == "AsyncSimpleStorageManager": - tq_config.transfer_queue.storage_unit_infos = self.data_system_storage_unit_infos - - self.config = OmegaConf.merge(tq_config, self.config) - - # 4. create client - create_transferqueue_client(client_id="Trainer", config=self.config.transfer_queue, sync=True) - tq_client = get_transferqueue_client() - return tq_client - - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): - """ - Creates the train and validation dataloaders. - """ - # TODO: we have to make sure the batch size is divisible by the dp size - from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler - - if train_dataset is None: - train_dataset = create_rl_dataset( - self.config.data.train_files, self.config.data, self.tokenizer, self.processor - ) - if val_dataset is None: - val_dataset = create_rl_dataset( - self.config.data.val_files, self.config.data, self.tokenizer, self.processor - ) - self.train_dataset, self.val_dataset = train_dataset, val_dataset - - self.val_dataset_size = len(val_dataset) - - if train_sampler is None: - train_sampler = create_rl_sampler(self.config.data, self.train_dataset) - if collate_fn is None: - from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn - - collate_fn = default_collate_fn - - num_workers = self.config.data["dataloader_num_workers"] - - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), - num_workers=num_workers, - drop_last=True, - collate_fn=collate_fn, - sampler=train_sampler, - ) - - val_batch_size = self.config.data.val_batch_size # Prefer config value if set - if val_batch_size is None: - val_batch_size = len(self.val_dataset) - self.val_batch_size = val_batch_size - - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - batch_size=val_batch_size, - num_workers=num_workers, - shuffle=self.config.data.get("validation_shuffle", True), - drop_last=False, - collate_fn=collate_fn, - ) - - assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" - assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" - - print( - f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " - f"{len(self.val_dataloader)}" - ) - - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - try: - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - if OmegaConf.select(self.config, "critic.optim"): - self.config.critic.optim.total_training_steps = total_training_steps - except Exception as e: - print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - - def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): - """Dump rollout/validation samples as JSONL.""" - os.makedirs(dump_path, exist_ok=True) - filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") - - n = len(inputs) - base_data = { - "input": inputs, - "output": outputs, - "gts": gts, - "score": scores, - "step": [self.global_steps] * n, - } - - for k, v in reward_extra_infos_dict.items(): - if len(v) == n: - base_data[k] = v - - lines = [] - for i in range(n): - entry = {k: v[i] for k, v in base_data.items()} - lines.append(json.dumps(entry, ensure_ascii=False)) - - with open(filename, "w") as f: - f.write("\n".join(lines) + "\n") - - print(f"Dumped generations to {filename}") - - def _log_rollout_data( - self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str - ): - """ - Log rollout data to disk. - - Args: - log_rollout_meta (BatchMeta): The batch_meta of rollout data - reward_extra_infos_dict (dict): Additional reward information to log - timing_raw (dict): Timing information for profiling - rollout_data_dir (str): Directory path to save the rollout data - """ - with marked_timer("dump_rollout_generations", timing_raw, color="green"): - data = self.tq_client.get_data(log_rollout_meta) - - inputs = self.tokenizer.batch_decode(data["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(data["responses"], skip_special_tokens=True) - scores = data["token_level_scores"].sum(-1).cpu().tolist() - sample_gts = [item.get("ground_truth", None) for item in data.get("reward_model", {})] - - reward_extra_infos_to_dump = reward_extra_infos_dict.copy() - if "request_id" in log_rollout_meta.field_names: - reward_extra_infos_dict.setdefault( - "request_id", - data["request_id"].tolist(), - ) - - self._dump_generations( - inputs=inputs, - outputs=outputs, - gts=sample_gts, - scores=scores, - reward_extra_infos_dict=reward_extra_infos_to_dump, - dump_path=rollout_data_dir, - ) - - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.log_val_generations - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores, strict=True)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - - def _get_gen_batch(self, batch: DataProto) -> DataProto: - reward_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() - - # pop those keys for generation - batch_keys_to_pop = [] - non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_keys - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), - ) - - # For agent loop, we need reward model keys to compute score. - if self.async_rollout_mode: - gen_batch.non_tensor_batch.update(batch.non_tensor_batch) - - return gen_batch - - def _validate(self): - data_source_lst = [] - reward_extra_infos_dict: dict[str, list] = defaultdict(list) - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_gts = [] - sample_scores = [] - sample_turns = [] - sample_uids = [] - - for test_data in self.val_dataloader: - if "uid" not in test_data.keys(): - test_data["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(test_data["raw_prompt"]))], dtype=object - ) - - # repeat test data - repeated_test_data = self.repeat_dict( - test_data, repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True - ) - - test_batch: TensorDict = self.dict_to_tensordict(repeated_test_data) - - # we only do validation on rule-based rm - if self.config.reward.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": - return {} - - batch_meta = self.tq_client.put(data=test_batch, partition_id=f"val_{self.global_steps - 1}") - - batch_meta.update_extra_info( - { - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - "recompute_log_prob": False, - "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, - "validate": True, - "global_steps": self.global_steps, - } - ) - print(f"batch_meta extra_info: {batch_meta.extra_info}") - - # TODO: (TQ) Support padding and unpadding to make DataProto divisible by dp_size with TransferQueue - if not self.async_rollout_mode: - test_output_gen_meta = self.actor_rollout_wg.generate_sequences(batch_meta) - else: - test_output_gen_meta = self.async_rollout_manager.generate_sequences(batch_meta) - - batch_meta = batch_meta.union(test_output_gen_meta) - - print("validation generation end") - - # Store generated outputs - test_response_meta = batch_meta.select_fields(["prompts", "responses", "uid", "reward_model"]) - data = self.tq_client.get_data(test_response_meta) - output_ids = data["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] - sample_outputs.extend(output_texts) - - # TODO: Can we keep special tokens except for padding tokens? - input_ids = data["prompts"] - input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] - sample_inputs.extend(input_texts) - sample_uids.extend(data["uid"]) - - ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})] - sample_gts.extend(ground_truths) - - reward_tensor, reward_extra_info = compute_reward_decorated(batch_meta) - - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_extra_infos_dict["reward"].extend(scores) - print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") - for key, lst in reward_extra_info.items(): - reward_extra_infos_dict[key].extend(lst) - print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") - - # collect num_turns of each prompt - if "__num_turns__" in batch_meta.field_names: - data = self.tq_client.get_data(batch_meta.select_fields(["__num_turns__"])) - sample_turns.append(data["__num_turns__"]) - - data_source = ["unknown"] * reward_tensor.shape[0] - if "data_source" in batch_meta.field_names: - data_source_meta = batch_meta.select_fields(["data_source"]) - data = self.tq_client.get_data(data_source_meta) - data_source = data["data_source"] - - data_source_lst.append(data_source) - - self.tq_client.clear_samples(batch_meta) - - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - - # dump generations - val_data_dir = self.config.trainer.get("validation_data_dir", None) - if val_data_dir: - self._dump_generations( - inputs=sample_inputs, - outputs=sample_outputs, - gts=sample_gts, - scores=sample_scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=val_data_dir, - ) - - for key_info, lst in reward_extra_infos_dict.items(): - assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" - - data_sources = np.concatenate(data_source_lst, axis=0) - - data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) - metric_dict = {} - for data_source, var2metric2val in data_src2var2metric2val.items(): - core_var = "acc" if "acc" in var2metric2val else "reward" - for var_name, metric2val in var2metric2val.items(): - n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) - for metric_name, metric_val in metric2val.items(): - if ( - (var_name == core_var) - and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) - and (f"@{n_max}" in metric_name) - ): - metric_sec = "val-core" - else: - metric_sec = "val-aux" - pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" - metric_dict[pfx] = metric_val - - if len(sample_turns) > 0: - sample_turns = np.concatenate(sample_turns) - metric_dict["val-aux/num_turns/min"] = sample_turns.min() - metric_dict["val-aux/num_turns/max"] = sample_turns.max() - metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() - - return metric_dict - - def init_workers(self): - """Initialize distributed training workers using Ray backend. - - Creates: - 1. Ray resource pools from configuration - 2. Worker groups for each role (actor, critic, etc.) - """ - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role="actor_rollout", - ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cfg = omega_conf_to_dataclass(self.config.critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role="ref", - ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. - # Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - wg_kwargs = {} # Setting up kwargs for RayWorkerGroup - if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: - wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout - if OmegaConf.select(self.config.global_profiler, "steps") is not None: - wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") - # Only require nsight worker options when tool is nsys - if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": - assert ( - OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") - is not None - ), "worker_nsight_options must be set when using nsys with profile_steps" - wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( - OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") - ) - wg_kwargs["device_name"] = self.device_name - - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, - ray_cls_with_init=worker_dict_cls, - **wg_kwargs, - ) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() - - if self.use_reference_policy and not self.ref_in_actor: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor_rollout"] - self.actor_rollout_wg.init_model() - - # set transferqueue server info for each worker - for _, wg in all_wg.items(): - wg.create_transferqueue_client(self.config) - - # create reward loop manager - from verl.experimental.reward_loop import RewardLoopManager - - # initalize reward loop manager - # reward model (colocate or standalone): get resource_pool - # no reward model: resource_pool = None - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None - self.reward_loop_manager = RewardLoopManager( - config=self.config, - rm_resource_pool=resource_pool, - ) - - # create async rollout manager and request scheduler - self.async_rollout_mode = False - if self.config.actor_rollout_ref.rollout.mode == "async": - from .agent_loop import AgentLoopManager - - self.async_rollout_mode = True - - enable_agent_reward_loop = not self.use_rm or self.config.reward.reward_model.enable_resource_pool - - reward_loop_worker_handles = ( - self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None - ) - self.async_rollout_manager = AgentLoopManager( - config=self.config, - worker_group=self.actor_rollout_wg, - reward_loop_worker_handles=reward_loop_worker_handles, - ) - - self.checkpoint_manager = CheckpointEngineManager( - backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, - trainer=self.actor_rollout_wg, - replicas=self.async_rollout_manager.rollout_replicas, - ) - - # sleep all replicas to load checkpoint - self.checkpoint_manager.sleep_replicas() - - # TODO (TQ): initialize tq during worker init when enable TQ switch is stable - self.async_rollout_manager.create_transferqueue_client_for_workers() - - def _save_checkpoint(self): - from verl.utils.fs import local_mkdir_safe - - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") - ) - - remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) - if remove_previous_ckpt_in_save: - print( - "Warning: remove_previous_ckpt_in_save is deprecated," - + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" - ) - max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - - self.actor_rollout_wg.save_checkpoint( - actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") - ) - self.critic_wg.save_checkpoint( - critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep - ) - - # save dataloader - local_mkdir_safe(local_global_step_folder) - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") - - def _start_profiling(self, do_profile: bool) -> None: - """Start profiling for all worker groups if profiling is enabled.""" - if do_profile: - self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) - if self.use_reference_policy: - self.ref_policy_wg.start_profile(profile_step=self.global_steps) - if self.use_critic: - self.critic_wg.start_profile(profile_step=self.global_steps) - - def _stop_profiling(self, do_profile: bool) -> None: - """Stop profiling for all worker groups if profiling is enabled.""" - if do_profile: - self.actor_rollout_wg.stop_profile() - if self.use_reference_policy: - self.ref_policy_wg.stop_profile() - if self.use_critic: - self.critic_wg.stop_profile() - - def _balance_batch( - self, batch: BatchMeta, tq_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False - ): - """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" - data = tq_client.get_data(batch) - - attention_mask = data["attention_mask"] - batch_size = attention_mask.shape[0] - global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) - global_seqlen_lst = calculate_workload(global_seqlen_lst) - world_size = self.actor_rollout_wg.world_size - if keep_minibatch: - # Decouple the DP balancing and mini-batching. - minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size", None) - if minibatch_size is None: - raise ValueError("'ppo_mini_batch_size' must be set in actor config when 'keep_minibatch' is True.") - minibatch_num = len(global_seqlen_lst) // minibatch_size - global_partition_lst = [[] for _ in range(world_size)] - for i in range(minibatch_num): - rearrange_minibatch_lst = get_seqlen_balanced_partitions( - global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size], - k_partitions=world_size, - equal_size=True, - ) - for j, part in enumerate(rearrange_minibatch_lst): - global_partition_lst[j].extend([x + minibatch_size * i for x in part]) - else: - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) - # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. - for idx, partition in enumerate(global_partition_lst): - partition.sort(key=lambda x: (global_seqlen_lst[x], x)) - ordered_partition = partition[::2] + partition[1::2][::-1] - global_partition_lst[idx] = ordered_partition - # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = [j for partition in global_partition_lst for j in partition] - global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix - ) - metrics.update(global_balance_stats) - return global_idx - - @classmethod - def repeat_dict( - cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True - ) -> dict[str, torch.Tensor | np.ndarray]: - """ - Repeat the batch dict a specified number of times. - - Args: - repeat_times (int): Number of times to repeat the data. - interleave (bool): Whether to interleave the repeated data. - - Returns: - dict: A new dict with repeated data. - """ - if repeat_times == 1: - return batch_dict - - repeated_batch_dict = {} - if batch_dict: - if interleave: - # Interleave the data - for key, val in batch_dict.items(): - if isinstance(val, torch.Tensor): - repeated_batch_dict[key] = val.repeat_interleave(repeat_times, dim=0) - elif isinstance(val, np.ndarray): - repeated_batch_dict[key] = np.repeat(val, repeat_times, axis=0) - else: - raise ValueError(f"Unsupported type in data {type(val)}") - else: - # Stack the data - for key, val in batch_dict.items(): - if isinstance(val, torch.Tensor): - repeated_batch_dict[key] = ( - val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, *val.shape[1:]) - ) - elif isinstance(val, np.ndarray): - repeated_batch_dict[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) - else: - raise ValueError(f"Unsupported type in data {type(val)}") - return repeated_batch_dict - - @classmethod - def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> TensorDict: - """ - Create a TensorDict from a dict of tensors and non_tensors. - Note that this requires tensordict version at least 0.10 - """ - assert parse_version(tensordict.__version__) >= parse_version("0.10"), ( - "Storing non-tensor data in TensorDict at least requires tensordict version 0.10" - ) - tensors_batch = {} - batch_size = None - - for key, val in data.items(): - if isinstance(val, torch.Tensor | np.ndarray): - tensors_batch[key] = val - else: - raise ValueError(f"Unsupported type in data {type(val)}") - - if batch_size is None: - batch_size = len(val) - else: - assert len(val) == batch_size - - if batch_size is None: - batch_size = [] - else: - batch_size = [batch_size] - - return TensorDict(tensors_batch, batch_size=batch_size) - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC - to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint and update weights before doing anything - self._load_checkpoint() - self.checkpoint_manager.update_weights() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - assert val_metrics, f"{val_metrics=}" - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): - rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) - rollout_skip.wrap_generate_sequences() - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - self.max_steps_duration = 0 - - prev_step_profile = False - curr_step_profile = ( - self.global_steps in self.config.global_profiler.steps - if self.config.global_profiler.steps is not None - else False - ) - next_step_profile = False - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - base_get_meta_kwargs = dict( - batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1 - ) - - with marked_timer("start_profile", timing_raw): - self._start_profiling( - not prev_step_profile and curr_step_profile - if self.config.global_profiler.profile_continuous_steps - else curr_step_profile - ) - - # add uid to batch - batch_dict["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch_dict["raw_prompt"]))], dtype=object - ) - # When n > 1, repeat input data before putting to data system, simulating DataProto repeat. - repeated_batch_dict = self.repeat_dict( - batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True - ) - batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) - gen_meta = self.tq_client.put(data=batch, partition_id=f"train_{self.global_steps - 1}") - - # pass global_steps to trace - gen_meta.set_extra_info("global_steps", self.global_steps) - - is_last_step = self.global_steps >= self.total_training_steps - - with marked_timer("step", timing_raw): - # generate a batch - with marked_timer("gen", timing_raw, color="red"): - if not self.async_rollout_mode: - gen_output_meta = self.actor_rollout_wg.generate_sequences(gen_meta) - else: - gen_output_meta = self.async_rollout_manager.generate_sequences(gen_meta) - self.checkpoint_manager.sleep_replicas() - timing_raw.update(gen_output_meta.extra_info["timing"]) - gen_output_meta.extra_info.pop("timing", None) - - # TODO (TQ): support transfer queue - # if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - # if self.reward_fn is None: - # raise ValueError("A reward_fn is required for REMAX advantage estimation.") - # - # with marked_timer("gen_max", timing_raw, color="purple"): - # gen_baseline_meta = deepcopy(gen_meta) - # gen_baseline_meta.extra_info["do_sample"] = False - # if not self.async_rollout_mode: - # gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_meta) - # else: - # gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_meta) - # batch = batch.union(gen_baseline_output) - # reward_baseline_tensor = self.reward_fn(batch) - # reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - # - # batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - # - # batch.batch["reward_baselines"] = reward_baseline_tensor - # - # del gen_baseline_batch, gen_baseline_output - - batch_meta: BatchMeta = gen_meta.union(gen_output_meta) - - if "response_mask" not in batch_meta.field_names: - response_mask_meta = self.tq_client.get_meta( - data_fields=["responses", "attention_mask"], - task_name="compute_response_mask", - **base_get_meta_kwargs, - ) - response_mask_output_meta = compute_response_mask(response_mask_meta, self.tq_client) - batch_meta = batch_meta.union(response_mask_output_meta) - - # Balance the number of valid tokens across DP ranks. - # NOTE: This usually changes the order of data in the `batch`, - # which won't affect the advantage calculation (since it's based on uid), - # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. - - attention_mask_meta = batch_meta.select_fields(["attention_mask"]) - balanced_idx = None - if self.config.trainer.balance_batch: - balanced_idx = self._balance_batch(attention_mask_meta, self.tq_client, metrics=metrics) - batch_meta.reorder(balanced_idx) - - # compute global_valid tokens - data = self.tq_client.get_data(attention_mask_meta) - batch_meta.extra_info["global_token_num"] = torch.sum(data["attention_mask"], dim=-1).tolist() - - with marked_timer("reward", timing_raw, color="yellow"): - # compute reward model score - if self.use_rm and "rm_scores" not in batch_meta.field_names: - reward_meta = self.rm_wg.compute_rm_score(batch_meta) - batch_meta = batch_meta.union(reward_meta) - - compute_reward_fields = [ - "responses", - "prompts", - "attention_mask", - "reward_model", - "data_source", - ] - if "rm_scores" in batch_meta.field_names: - compute_reward_fields.extend( - ["rm_scores", *set(batch_meta.extra_info["reward_extra_keys"])] - ) - - reward_tensor, reward_extra_infos_dict = compute_reward_decorated(batch_meta) - - compute_reward_meta = batch_meta.select_fields(compute_reward_fields) - batch_meta = batch_meta.union(compute_reward_meta) - - # recompute old_log_probs - with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob_meta_fields = [ - "input_ids", - "attention_mask", - "position_ids", - "prompts", - "responses", - "response_mask", - "data_source", - "reward_model", - "extra_info", - "uid", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - ] - old_log_prob_meta = batch_meta.select_fields(old_log_prob_meta_fields) - old_log_prob_output_meta = self.actor_rollout_wg.compute_log_prob(old_log_prob_meta) - batch_meta = batch_meta.union(old_log_prob_output_meta) - - data = self.tq_client.get_data(old_log_prob_output_meta) - entropys = data["entropys"] - response_masks = data["response_mask"] - actor_config = self.config.actor_rollout_ref.actor - entropy_agg = agg_loss( - loss_mat=entropys, - loss_mask=response_masks, - loss_agg_mode=actor_config.loss_agg_mode, - loss_scale_factor=actor_config.loss_scale_factor, - ) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} - metrics.update(old_log_prob_metrics) - - if "rollout_log_probs" in batch_meta.field_names: - # TODO: we may want to add diff of probs too. - calculate_debug_metrics_fields = ["rollout_log_probs", "old_log_probs", "responses"] - - if "response_mask" in batch_meta.field_names: - calculate_debug_metrics_fields.append("response_mask") - if "attention_mask" in batch_meta.field_names: - calculate_debug_metrics_fields.append("attention_mask") - - calculate_debug_metrics_meta = batch_meta.select_fields(calculate_debug_metrics_fields) - metrics.update(calculate_debug_metrics_decorated(calculate_debug_metrics_meta)) - - if self.use_reference_policy: - # compute reference log_prob - ref_log_prob_fields = [ - "input_ids", - "attention_mask", - "position_ids", - "prompts", - "responses", - "response_mask", - "old_log_probs", - "data_source", - "reward_model", - "extra_info", - "uid", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - ] - ref_log_prob_meta = batch_meta.select_fields(ref_log_prob_fields) - - with marked_timer("ref", timing_raw, color="olive"): - if not self.ref_in_actor: - ref_log_prob_output_meta = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_meta) - else: - ref_log_prob_output_meta = self.actor_rollout_wg.compute_ref_log_prob(ref_log_prob_meta) - batch_meta = batch_meta.union(ref_log_prob_output_meta) - - # compute values - if self.use_critic: - with marked_timer("values", timing_raw, color="cyan"): - values_meta = self.critic_wg.compute_values(batch_meta) - batch_meta = batch_meta.union(values_meta) - - with marked_timer("adv", timing_raw, color="brown"): - # we combine with rule-based rm - reward_extra_infos_dict: dict[str, list] - reward_td = TensorDict({"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0)) - batch_meta = self.tq_client.put(data=reward_td, metadata=batch_meta) - - if reward_extra_infos_dict: - reward_extra_infos_dict_new = {k: np.array(v) for k, v in reward_extra_infos_dict.items()} - reward_extra_infos_td = self.dict_to_tensordict(reward_extra_infos_dict_new) - batch_meta = self.tq_client.put(data=reward_extra_infos_td, metadata=batch_meta) - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - apply_kl_penalty_fields = [ - "response_mask", - "token_level_scores", - "old_log_probs", - "ref_log_prob", - ] - - apply_kl_penalty_meta = batch_meta.select_fields(apply_kl_penalty_fields) - - token_level_rewards, kl_metrics = apply_kl_penalty( - apply_kl_penalty_meta, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty, - ) - token_level_rewards_td = TensorDict( - {"token_level_rewards": token_level_rewards}, batch_size=token_level_rewards.size(0) - ) - apply_kl_penalty_meta = self.tq_client.put( - data=token_level_rewards_td, metadata=apply_kl_penalty_meta - ) - - metrics.update(kl_metrics) - batch_meta = batch_meta.union(apply_kl_penalty_meta) - else: - token_level_scores_meta = batch_meta.select_fields(["token_level_scores"]) - - data = self.tq_client.get_data(token_level_scores_meta) - token_level_rewards_td = TensorDict( - {"token_level_rewards": data["token_level_scores"]}, - batch_size=data["token_level_scores"].size(0), - ) - token_level_scores_meta = self.tq_client.put( - data=token_level_rewards_td, metadata=token_level_scores_meta - ) - batch_meta = batch_meta.union(token_level_scores_meta) - - # compute advantages, executed on the driver process - - norm_adv_by_std_in_grpo = self.config.algorithm.get( - "norm_adv_by_std_in_grpo", True - ) # GRPO adv normalization factor - - assert "response_mask" in batch_meta.field_names, ( - f"`response_mask` must be in batch_meta {batch_meta.field_names} for advantage computation" - ) - compute_advantage_fields = [ - "response_mask", - "token_level_rewards", - ] - if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - compute_advantage_fields.append("values") - elif self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO: - compute_advantage_fields.append("uid") - else: - if "uid" in batch_meta.field_names: - compute_advantage_fields.append("uid") - if "reward_baselines" in batch_meta.field_names: - compute_advantage_fields.append("reward_baselines") - - compute_advantage_meta = batch_meta.select_fields(compute_advantage_fields) - - advantages, returns = compute_advantage( - compute_advantage_meta, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - config=self.config.algorithm, - ) - - advantages_td = TensorDict( - {"advantages": advantages, "returns": returns}, batch_size=advantages.size(0) - ) - compute_advantage_meta = self.tq_client.put(data=advantages_td, metadata=compute_advantage_meta) - batch_meta = batch_meta.union(compute_advantage_meta) - - # update critic - if self.use_critic: - with marked_timer("update_critic", timing_raw, color="pink"): - critic_output_meta = self.critic_wg.update_critic(batch_meta) - batch_meta = batch_meta.union(critic_output_meta) - critic_output_metrics = reduce_metrics(critic_output_meta.extra_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with marked_timer("update_actor", timing_raw, color="red"): - batch_meta.extra_info["multi_turn"] = ( - self.config.actor_rollout_ref.rollout.multi_turn.enable - ) - - update_actor_fields = [ - "input_ids", - "attention_mask", - "position_ids", - "prompts", - "responses", - "response_mask", - "old_log_probs", - "ref_log_prob", - "advantages", - "returns", - "token_level_rewards", - "token_level_scores", - "data_source", - "reward_model", - "extra_info", - "uid", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - ] - update_actor_meta = batch_meta.select_fields(update_actor_fields) - - update_actor_meta.set_extra_info( - "global_token_num", batch_meta.get_extra_info("global_token_num") - ) - update_actor_meta.set_extra_info("temperature", batch_meta.get_extra_info("temperature")) - - actor_output_meta = self.actor_rollout_wg.update_actor(update_actor_meta) - batch_meta = batch_meta.union(actor_output_meta) - - # update weights from trainer to rollout - with marked_timer("update_weights", timing_raw, color="red"): - self.checkpoint_manager.update_weights() - - actor_output_metrics = reduce_metrics(actor_output_meta.extra_info["metrics"]) - metrics.update(actor_output_metrics) - - # Log rollout generations if enabled - rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) - if rollout_data_dir: - log_rollout_fields = ["prompts", "responses", "token_level_scores", "reward_model"] - if "request_id" in batch_meta.field_names: - log_rollout_fields.append("request_id") - log_rollout_meta = batch_meta.select_fields(log_rollout_fields) - self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir) - - # TODO: validate - if self.config.trainer.test_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.test_freq == 0 - ): - with marked_timer("testing", timing_raw, color="green"): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. - esi_close_to_expiration = should_save_ckpt_esi( - max_steps_duration=self.max_steps_duration, - redundant_time=self.config.trainer.esi_redundant_time, - ) - # Check if the conditions for saving a checkpoint are met. - # The conditions include a mandatory condition (1) and - # one of the following optional conditions (2/3/4): - # 1. The save frequency is set to a positive value. - # 2. It's the last training step. - # 3. The current step number is a multiple of the save frequency. - # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration - ): - if esi_close_to_expiration: - print("Force saving checkpoint: ESI instance expiration approaching.") - with marked_timer("save_checkpoint", timing_raw, color="green"): - self._save_checkpoint() - - with marked_timer("stop_profile", timing_raw): - next_step_profile = ( - self.global_steps + 1 in self.config.global_profiler.steps - if self.config.global_profiler.steps is not None - else False - ) - self._stop_profiling( - curr_step_profile and not next_step_profile - if self.config.global_profiler.profile_continuous_steps - else curr_step_profile - ) - prev_step_profile = curr_step_profile - curr_step_profile = next_step_profile - - steps_duration = timing_raw["step"] - self.max_steps_duration = max(self.max_steps_duration, steps_duration) - - # training metrics - metrics.update( - { - "training/global_step": self.global_steps, - "training/epoch": epoch, - } - ) - # collect metrics - compute_data_metrics_fields = [ - "token_level_rewards", - "token_level_scores", - "advantages", - "returns", - "responses", - "attention_mask", - "response_mask", - ] - if "__num_turns__" in batch_meta.field_names: - compute_data_metrics_fields.append("__num_turns__") - if "tool_call_counts" in batch_meta.field_names: - compute_data_metrics_fields.append("tool_call_counts") - compute_data_metrics_meta = batch_meta.select_fields(compute_data_metrics_fields) - compute_data_metrics_meta.reorder(balanced_idx) - metrics.update( - compute_data_metrics_decorated(batch=compute_data_metrics_meta, use_critic=self.use_critic) - ) - - compute_timing_metrics_fields = ["responses", "attention_mask"] - compute_timing_metrics_meta = batch_meta.select_fields(compute_timing_metrics_fields) - compute_timing_metrics_meta.reorder(balanced_idx) - metrics.update( - compute_timing_metrics_decorated(batch=compute_timing_metrics_meta, timing_raw=timing_raw) - ) - - compute_throughout_metrics_meta = BatchMeta( - samples=[], - extra_info={"global_token_num": batch_meta.get_extra_info("global_token_num")}, - ) - # TODO: implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update( - compute_throughout_metrics_decorated( - batch=compute_throughout_metrics_meta, timing_raw=timing_raw, n_gpus=n_gpus - ) - ) - - # this is experimental and may be changed/removed in the future in favor of a general-purpose one - if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): - # TODO (TQ) :support transfer queue - self.train_dataloader.sampler.update(batch=batch) - - self.tq_client.clear_samples(batch_meta) - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - progress_bar.update(1) - self.global_steps += 1 - - if ( - hasattr(self.config.actor_rollout_ref.actor, "profiler") - and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" - ): - self.actor_rollout_wg.dump_memory_snapshot( - tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" - ) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - # this is experimental and may be changed/removed in the future - # in favor of a general-purpose data buffer pool - if hasattr(self.train_dataset, "on_batch_end"): - # The dataset may be changed after each training batch - # TODO (TQ): support transfer queue - self.train_dataset.on_batch_end(batch=batch) diff --git a/verl/experimental/transfer_queue/run_qwen3-8b_transferqueue.sh b/verl/experimental/transfer_queue/run_qwen3-8b_transferqueue.sh deleted file mode 100644 index bd6d09e32d7..00000000000 --- a/verl/experimental/transfer_queue/run_qwen3-8b_transferqueue.sh +++ /dev/null @@ -1,70 +0,0 @@ -set -x - -MODEL_PATH="/workspace/models/Qwen3-8B" -TRAIN_FILE="/workspace/datasets/preprocessed/gsm8k/train.parquet" -TEST_FILE="/workspace/datasets/preprocessed/gsm8k/test.parquet" - -log_dir="./logs" -mkdir -p ${log_dir} -timestamp=$(date +"%Y%m%d%H%M%S") -log_file="${log_dir}/qwen3-8b_tq_${timestamp}.log" - -# You may try to enable zero-copy serialization for TransferQueue when using SimpleStorageUnit backend. -export TQ_ZERO_COPY_SERIALIZATION=False - -rollout_mode="async" -rollout_name="vllm" # sglang or vllm -if [ "$rollout_mode" = "async" ]; then - export VLLM_USE_V1=1 - return_raw_chat="True" -fi - -# You may also refer to tests/special_e2e/run_transferqueue.sh for more demo scripts - -python3 -m verl.experimental.transfer_queue.main_ppo \ - --config-name='transfer_queue_ppo_trainer' \ - algorithm.adv_estimator=grpo \ - data.train_files=${TRAIN_FILE} \ - data.val_files=${TEST_FILE} \ - data.return_raw_chat=$return_raw_chat \ - data.train_batch_size=128 \ - data.max_prompt_length=2048 \ - data.max_response_length=8192 \ - data.filter_overlong_prompts_workers=128 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=${MODEL_PATH} \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ - actor_rollout_ref.rollout.name=$rollout_name \ - actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen3_8b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1000 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=2 \ - trainer.val_before_train=False \ - 2>&1 | tee "$log_file" -echo "Finished, log is saved in: $log_file" \ No newline at end of file diff --git a/verl/experimental/vla/rob_ray_trainer.py b/verl/experimental/vla/rob_ray_trainer.py index 619c822c75d..519b7dab4b5 100644 --- a/verl/experimental/vla/rob_ray_trainer.py +++ b/verl/experimental/vla/rob_ray_trainer.py @@ -600,13 +600,10 @@ def _validate(self): # pad to be divisible by dp_size size_divisor = self.config.env.train.num_envs * self.config.env.rollout.pipeline_stage_num test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) - if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) - else: - reset_future = self._reset_envs(test_gen_batch_padded) - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences( - test_gen_batch_padded, reset_future - ) + reset_future = self._reset_envs(test_gen_batch_padded) + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences( + test_gen_batch_padded, reset_future + ) # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) diff --git a/verl/experimental/vla/sac/sac_ray_trainer.py b/verl/experimental/vla/sac/sac_ray_trainer.py index cbdc1d597a2..310a9dc7686 100644 --- a/verl/experimental/vla/sac/sac_ray_trainer.py +++ b/verl/experimental/vla/sac/sac_ray_trainer.py @@ -528,13 +528,10 @@ def _validate(self): # pad to be divisible by dp_size size_divisor = self.config.env.train.num_envs * self.config.env.rollout.pipeline_stage_num test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) - if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) - else: - reset_future = self._reset_envs(test_gen_batch_padded) - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences( - test_gen_batch_padded, reset_future - ) + reset_future = self._reset_envs(test_gen_batch_padded) + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences( + test_gen_batch_padded, reset_future + ) # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) diff --git a/verl/models/mcore/mtp_patch.py b/verl/models/mcore/mtp_patch.py index 117b6e3f28c..fadf5b7bd52 100644 --- a/verl/models/mcore/mtp_patch.py +++ b/verl/models/mcore/mtp_patch.py @@ -20,11 +20,7 @@ import torch from megatron.core import parallel_state from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.transformer.multi_token_prediction import ( - MTPLossAutoScaler, - MTPLossLoggingHelper, - roll_tensor, -) +from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor try: from megatron.core.utils import unwrap_model @@ -78,19 +74,45 @@ def _megatron_gptmodel_postprocess( runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, + **kwargs, ): - """Postprocesses decoder hidden states to generate logits or compute loss. + """Compatibility patch for GPTModel._postprocess. - Applies Multi-Token Prediction if enabled, generates output logits through - the output layer, and computes language model loss when labels are provided. + For inference (`labels is None`), delegate to the upstream implementation to stay + aligned with Megatron-Core updates. + + For training (`labels is not None`), keep VERL's MTP behavior and always return + logits (instead of CE loss) so PPO paths can compute custom losses from logits. """ + # Keep inference path aligned with whatever upstream Megatron currently expects. + if labels is None: + return self._postprocess_backup( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=mtp_in_postprocess, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + **kwargs, + ) - # logits and loss + # Training path: keep logits for external loss computation. output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess and labels is not None: + if mtp_in_postprocess: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -109,60 +131,85 @@ def _megatron_gptmodel_postprocess( if not self.post_process: return hidden_states - # Skip when mtp_num_layers is None or 0 - if self.config.mtp_num_layers and labels is not None: - mtp_labels = labels.clone() - - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( - mtp_labels, - shifts=-1, - dims=-1, - cp_group=self.cp_group, + # Skip when mtp_num_layers is None or 0. + if self.config.mtp_num_layers: + cp_group = None + if getattr(self, "pg_collection", None) is not None: + cp_group = self.pg_collection.cp + elif hasattr(self, "cp_group"): + cp_group = self.cp_group + + # Prefer upstream helper when available (newer Megatron-LM). + try: + from megatron.core.transformer.multi_token_prediction import process_mtp_loss + + hidden_states = process_mtp_loss( + hidden_states=hidden_states, + labels=labels, + loss_mask=loss_mask, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + is_training=self.training, + compute_language_model_loss=self.compute_language_model_loss, + config=self.config, + cp_group=cp_group, packed_seq_params=packed_seq_params, ) - loss_mask, num_tokens = roll_tensor( - loss_mask, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) - - # Compute mtp loss without storing logits to save memory. - mtp_loss = self.compute_output_layer_and_language_model_loss( - hidden_states_list[mtp_layer_number + 1], - labels=mtp_labels, - weight=self.shared_embedding_or_output_weight(), - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ - "weight": output_weight, - "runtime_gather_output": runtime_gather_output, - }, - ) + except (ImportError, AttributeError, TypeError): + # Fallback for older Megatron-LM versions without process_mtp_loss API. + mtp_labels = labels.clone() + + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( + mtp_labels, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + loss_mask, num_tokens = roll_tensor( + loss_mask, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) - mtp_loss = loss_mask * mtp_loss - if self.training: - # TODO(shifangx): remove the use of parallel_state here - # after moving loss logging to loss_func in pretrain_gpt.py - MTPLossLoggingHelper.save_loss_to_tracker( - torch.sum(mtp_loss) / num_tokens, - mtp_layer_number, - self.config.mtp_num_layers, - avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + # Compute mtp loss without storing logits to save memory. + mtp_loss = self.compute_output_layer_and_language_model_loss( + hidden_states_list[mtp_layer_number + 1], + labels=mtp_labels, + weight=self.shared_embedding_or_output_weight(), + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ + "weight": output_weight, + "runtime_gather_output": runtime_gather_output, + }, ) - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers - if self.config.calculate_per_token_loss: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) - else: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) # [s b h] => [b s h] diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index aefb798aa0b..dc35310c894 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import math +import os import torch from megatron.core import parallel_state as mpu @@ -21,6 +23,9 @@ from verl.utils.model import CausalLMOutputForPPO +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + def preprocess_packed_seqs( input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False @@ -333,6 +338,19 @@ def preprocess_thd_no_padding( start_idx = cu_seqlens_padded_cpu[i] // cp_size # split to 2 chunks d = input_ids[i] + # If the number of elements in `d` is smaller than the required + # alignment size, pad the tensor with zeros so that its total + # length matches `align_size`. This ensures size alignment for + # downstream operations (e.g., communication or memory alignment). + if d.numel() < align_size: + original_size = d.numel() + pad = torch.zeros(align_size - d.numel(), dtype=d.dtype, device=d.device) + d = torch.cat([d, pad], dim=0) + logger.warning_once( + f"Padding tensor for context parallel alignment, original_size={original_size}, " + f"align_size={align_size}" + ) + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) ] diff --git a/verl/protocol.py b/verl/protocol.py index 27a1f6a1f94..820b9bd3462 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -36,7 +36,7 @@ from torch.utils.data import DataLoader from verl.utils.device import get_device_id, get_torch_device -from verl.utils.py_functional import union_two_dict +from verl.utils.py_functional import list_of_dict_to_dict_of_list, union_two_dict from verl.utils.torch_functional import allgather_dict_tensors __all__ = ["DataProto", "union_tensor_dict"] @@ -198,18 +198,6 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str return tensor_dict1 -def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): - if len(list_of_dict) == 0: - return {} - keys = list_of_dict[0].keys() - output = {key: [] for key in keys} - for data in list_of_dict: - for key, item in data.items(): - assert key in output - output[key].append(item) - return output - - def fold_batch_dim(data: "DataProto", new_batch_size): """ Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx] diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 540c4e00552..c8438309879 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -20,7 +20,6 @@ from verl.protocol import DataProtoFuture, _padding_size_key from verl.utils.py_functional import DynamicEnum from verl.utils.tensordict_utils import chunk_tensordict, concat_tensordict, contiguous -from verl.utils.transferqueue_utils import BatchMeta # here we add a magic number of avoid user-defined function already have this attribute MAGIC_ATTR = "attrs_3141562937" @@ -80,7 +79,7 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs): splitted_args = [] for arg in args: - assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta | TensorDict) + assert isinstance(arg, DataProto | DataProtoFuture | TensorDict) if isinstance(arg, TensorDict): chunked_arg = chunk_tensordict(arg, chunks) chunked_arg = _consolidate_tuple_td(chunked_arg) @@ -91,7 +90,7 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs): splitted_kwargs = {} for key, val in kwargs.items(): - assert isinstance(val, DataProto | DataProtoFuture | BatchMeta | TensorDict) + assert isinstance(val, DataProto | DataProtoFuture | TensorDict) if isinstance(val, TensorDict): chunked_kwarg = chunk_tensordict(val, chunks) chunked_kwarg = _consolidate_tuple_td(chunked_kwarg) @@ -165,8 +164,6 @@ def _concat_data_proto_or_future(output: list): return DataProto.concat(output) elif isinstance(o, ray.ObjectRef): return DataProtoFuture.concat(output) - elif isinstance(o, BatchMeta): - return BatchMeta.concat(output) elif isinstance(o, TensorDict): return concat_tensordict(output) else: @@ -288,8 +285,8 @@ def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output) from verl.protocol import DataProto for o in output: - assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), ( - f"expecting {o} to be DataProto | ray.ObjectRef | BatchMeta | TensorDict, but got {type(o)}" + assert isinstance(o, DataProto | ray.ObjectRef | TensorDict), ( + f"expecting {o} to be DataProto | ray.ObjectRef | TensorDict, but got {type(o)}" ) return _concat_data_proto_or_future(output) @@ -447,14 +444,11 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki A decorator that wraps the original function with distributed execution configuration. """ - from verl.utils.transferqueue_utils import tqbridge _check_dispatch_mode(dispatch_mode=dispatch_mode) _check_execute_mode(execute_mode=execute_mode) def decorator(func): - func = tqbridge(dispatch_mode=dispatch_mode)(func) - @wraps(func) def inner(*args, **kwargs): if materialize_futures: diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index cffaf5d30a0..6fc955cb049 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -165,15 +165,6 @@ def set_dispatch_collect(self, mesh_name: str, dispatch_dp_rank: dict[str, int], for is_collect in collect_dp_rank.values(): self.__collect_dp_rank[mesh_name] = is_collect - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True) - def create_transferqueue_client(self, config): - from verl.utils.transferqueue_utils import create_transferqueue_client - - create_transferqueue_client( - client_id=f"worker_{self.rank}", - config=config.transfer_queue, - ) - @classmethod def env_keys(cls): """The keys of the environment variables that are used to configure the Worker.""" diff --git a/verl/trainer/README.md b/verl/trainer/README.md new file mode 100644 index 00000000000..d881d689b58 --- /dev/null +++ b/verl/trainer/README.md @@ -0,0 +1,16 @@ +# verl Main Entrypoints + +## SFT Trainer +- sft_trainer.py: SFT trainer based on model engine, support various backends: fsdp, megatron, veomni, torchtitan. Launched by `torchrun` and run in multi-controller mode. +- **[EXPERIMENTAL]** sft_trainer_ray.py: SFT trainer based on model engine with single-controller mode. Launched by ray with a driver process coordinating multiple worker processes. + +## RL Trainer +|trainer|description|sync/async|trainer/rollout|partial rollout| +|----|----|----|----|----| +|main_ppo.py|rollout until a batch is completed, then train|synchronous|colocated|No| +|TBD|[kimi-1.5](https://arxiv.org/pdf/2501.12599) style trainer: streaming rollout with capped length partial rollout|asynchronous|colocated|Yes| +|TBD|[Areal](https://arxiv.org/pdf/2505.24298) style trainer: fully decoupled trainer and rollout with staleness control|asynchronous|disaggregated|Yes| + +## Inference and Evaluation +- main_generation_server.py: Launch standalone servers and generate responses for a specified prompt dataset. +- main_eval.py: Evaluate the performance of generated responses with reward function on a specified prompt dataset. diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index a6e80b62cd2..09391ec6af3 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -216,6 +216,8 @@ actor_rollout_ref: _target_: verl.workers.config.RolloutConfig name: ??? mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} temperature: 1.0 top_k: -1 top_p: 1 @@ -290,6 +292,8 @@ actor_rollout_ref: engine_kwargs: {} trace: _target_: verl.workers.config.TraceConfig + project_name: ${oc.select:trainer.project_name,null} + experiment_name: ${oc.select:trainer.experiment_name,null} backend: null token2text: false max_samples_per_step_per_worker: null @@ -324,6 +328,7 @@ actor_rollout_ref: quantization: null quantization_config_file: null mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + qat: ${oc.select:actor_rollout_ref.actor.qat,null} layer_name_map: qkv_layer_name: qkv gate_proj_layer_name: gate_up @@ -585,6 +590,10 @@ reward_model: reward_loop_source: null reward_loop_module_path: null reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null rollout: name: null dtype: null diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml new file mode 100644 index 00000000000..b923da853ec --- /dev/null +++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml @@ -0,0 +1,677 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job model_engine=torchtitan' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.TorchtitanOptimizerConfig + name: AdamW + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + eps: 1.0e-08 + decay_type: linear + min_lr_factor: 0.0 + torchtitan: + _target_: verl.workers.config.TorchtitanEngineConfig + param_offload: false + optimizer_offload: false + wrap_policy: + min_num_params: 0 + reshard_after_forward: default + forward_prefetch: false + use_orig_params: false + mixed_precision: false + use_torch_compile: true + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + data_parallel_size: 1 + data_parallel_replicate_size: 1 + data_parallel_shard_size: 1 + tensor_parallel_size: 1 + expert_parallel_size: 1 + pipeline_parallel_size: 1 + context_parallel_size: 1 + attn_type: flex + strategy: torchtitan + seed: 42 + full_determinism: false + forward_only: false + dtype: bfloat16 + _target_: verl.workers.config.TorchTitanActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: torchtitan + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_prefix_grouper: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + ref: + optim: + _target_: verl.workers.config.TorchtitanOptimizerConfig + name: AdamW + lr: 0.001 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + eps: 1.0e-08 + decay_type: linear + min_lr_factor: 0.0 + torchtitan: + _target_: verl.workers.config.TorchtitanEngineConfig + param_offload: false + optimizer_offload: false + wrap_policy: + min_num_params: 0 + reshard_after_forward: default + forward_prefetch: false + use_orig_params: false + mixed_precision: false + use_torch_compile: true + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + data_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_size,1} + data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_replicate_size,1} + data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_shard_size,1} + tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.tensor_parallel_size,1} + expert_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.expert_parallel_size,1} + pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.pipeline_parallel_size,1} + context_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.context_parallel_size,1} + attn_type: ${oc.select:actor_rollout_ref.actor.torchtitan.attn_type,flex} + strategy: torchtitan + seed: ${oc.select:actor_rollout_ref.actor.torchtitan.seed,42} + full_determinism: false + forward_only: true + dtype: bfloat16 + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: torchtitan + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + _target_: verl.workers.config.TorchTitanActorConfig + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 2 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + logprobs_mode: processed_logprobs + scheduling_policy: fcfs + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + trtllm: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 2048 + engine_kwargs: {} + trace: + _target_: verl.workers.config.TraceConfig + project_name: ${oc.select:trainer.project_name,null} + experiment_name: ${oc.select:trainer.experiment_name,null} + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.contents,[]} + level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0} + analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false} + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + qat: ${oc.select:actor_rollout_ref.actor.qat,null} + layered_summon: false + model: + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: true + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + mtp: + _target_: verl.workers.config.MtpConfig + enable: false + enable_train: false + enable_rollout: false + detach_encoder: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + method: mtp + num_speculative_tokens: 1 + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +critic: + optim: + _target_: verl.workers.config.TorchtitanOptimizerConfig + name: AdamW + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + eps: 1.0e-08 + decay_type: linear + min_lr_factor: 0.0 + torchtitan: + _target_: verl.workers.config.TorchtitanEngineConfig + param_offload: false + optimizer_offload: false + wrap_policy: + min_num_params: 0 + reshard_after_forward: default + forward_prefetch: false + use_orig_params: false + mixed_precision: false + use_torch_compile: true + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + data_parallel_size: 1 + data_parallel_replicate_size: 1 + data_parallel_shard_size: 1 + tensor_parallel_size: 1 + expert_parallel_size: 1 + pipeline_parallel_size: 1 + context_parallel_size: 1 + attn_type: flex + strategy: torchtitan + seed: 42 + full_determinism: false + forward_only: false + dtype: bfloat16 + _target_: verl.workers.config.TorchTitanCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: torchtitan + enable: null + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.trainer.config.BaseModelConfig + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} +custom_reward_function: + path: null + name: null +reward_model: + num_workers: null + reward_manager: null + enable: null + enable_resource_pool: null + n_gpus_per_node: null + nnodes: null + reward_loop_source: null + reward_loop_module_path: null + reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null + rollout: + name: null + dtype: null + gpu_memory_utilization: null + enforce_eager: null + cudagraph_capture_sizes: null + free_cache_engine: null + data_parallel_size: null + expert_parallel_size: null + tensor_model_parallel_size: null + max_num_batched_tokens: null + max_model_len: null + max_num_seqs: null + load_format: null + engine_kwargs: null + limit_images: null + enable_chunked_prefill: null + enable_prefix_caching: null + disable_log_stats: null + skip_tokenizer_init: null + prompt_length: null + response_length: null +sandbox_fusion: + url: null + max_concurrent: null + memory_limit_mb: null +reward: + num_workers: 8 + custom_reward_function: + path: null + name: compute_score + reward_manager: + _target_: verl.workers.config.reward_model.RewardManagerConfig + source: register + name: naive + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager + reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + model_path: null + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 +trainer: + balance_batch: true + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 555cea354ae..1cdc21b1ec8 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -21,6 +21,7 @@ actor_rollout_ref: min_lr_ratio: 0.0 num_cycles: 0.5 lr_scheduler_type: constant + zero_indexed_step: true warmup_style: null override_optimizer_config: null fsdp_config: @@ -126,6 +127,16 @@ actor_rollout_ref: use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} calculate_sum_pi_squared: false sum_pi_squared_checkpointing: false + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: ${actor_rollout_ref.actor.strategy} @@ -193,6 +204,8 @@ actor_rollout_ref: _target_: verl.workers.config.RolloutConfig name: ??? mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} temperature: 1.0 top_k: -1 top_p: 1 @@ -267,6 +280,8 @@ actor_rollout_ref: engine_kwargs: {} trace: _target_: verl.workers.config.TraceConfig + project_name: ${oc.select:trainer.project_name,null} + experiment_name: ${oc.select:trainer.experiment_name,null} backend: null token2text: false max_samples_per_step_per_worker: null @@ -301,6 +316,7 @@ actor_rollout_ref: quantization: null quantization_config_file: null mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + qat: ${oc.select:actor_rollout_ref.actor.qat,null} layered_summon: false model: _target_: verl.workers.config.HFModelConfig @@ -399,6 +415,7 @@ critic: min_lr_ratio: 0.0 num_cycles: 0.5 lr_scheduler_type: constant + zero_indexed_step: true warmup_style: null override_optimizer_config: null model: @@ -505,6 +522,10 @@ reward_model: reward_loop_source: null reward_loop_module_path: null reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null rollout: name: null dtype: null diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index 80ffe2b68fc..ccaf6582902 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -26,14 +26,9 @@ actor_rollout_ref: _target_: verl.workers.config.VeOmniEngineConfig param_offload: false optimizer_offload: false - data_parallel_size: 1 - data_parallel_replicate_size: 1 - data_parallel_shard_size: 1 - tensor_parallel_size: 1 - expert_parallel_size: 1 - pipeline_parallel_size: 1 - context_parallel_size: 1 + fsdp_size: -1 ulysses_parallel_size: 1 + expert_parallel_size: 1 mixed_precision: true seed: 42 full_determinism: false @@ -167,14 +162,9 @@ actor_rollout_ref: _target_: verl.workers.config.VeOmniEngineConfig param_offload: ${oc.select:actor_rollout_ref.actor.veomni.param_offload,False} optimizer_offload: false - data_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_size,1} - data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_replicate_size,1} - data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_shard_size,1} - tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.tensor_parallel_size,1} - expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1} - pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.pipeline_parallel_size,1} - context_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.context_parallel_size,1} + fsdp_size: ${oc.select:actor_rollout_ref.actor.veomni.fsdp_size,-1} ulysses_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.ulysses_parallel_size,1} + expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1} mixed_precision: true seed: ${oc.select:actor_rollout_ref.actor.veomni.seed,42} full_determinism: false @@ -196,6 +186,8 @@ actor_rollout_ref: _target_: verl.workers.config.RolloutConfig name: ??? mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} temperature: 1.0 top_k: -1 top_p: 1 @@ -270,6 +262,8 @@ actor_rollout_ref: engine_kwargs: {} trace: _target_: verl.workers.config.TraceConfig + project_name: ${oc.select:trainer.project_name,null} + experiment_name: ${oc.select:trainer.experiment_name,null} backend: null token2text: false max_samples_per_step_per_worker: null @@ -304,6 +298,7 @@ actor_rollout_ref: quantization: null quantization_config_file: null mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + qat: ${oc.select:actor_rollout_ref.actor.qat,null} layered_summon: false model: _target_: verl.workers.config.HFModelConfig @@ -407,14 +402,9 @@ critic: _target_: verl.workers.config.VeOmniEngineConfig param_offload: false optimizer_offload: false - data_parallel_size: 1 - data_parallel_replicate_size: 1 - data_parallel_shard_size: 1 - tensor_parallel_size: 1 - expert_parallel_size: 1 - pipeline_parallel_size: 1 - context_parallel_size: 1 + fsdp_size: -1 ulysses_parallel_size: 1 + expert_parallel_size: 1 mixed_precision: true seed: 42 full_determinism: false @@ -500,6 +490,10 @@ reward_model: reward_loop_source: null reward_loop_module_path: null reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null rollout: name: null dtype: null diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml index fc0a16be609..7fbe49c019e 100644 --- a/verl/trainer/config/actor/dp_actor.yaml +++ b/verl/trainer/config/actor/dp_actor.yaml @@ -48,3 +48,35 @@ calculate_sum_pi_squared: False # Enable gradient checkpointing for sum_pi_squared computation (saves memory) sum_pi_squared_checkpointing: False + +# QAT (Quantization-Aware Training) configuration +# When enabled: +# - QAT is automatically applied to actor model during training +# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency +# - Fast quantization is used when syncing weights to vLLM rollout +# Supported modes: "w4a16" (NVFP4 weight-only) +# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use. +# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md +qat: + + # Whether to enable QAT + enable: false + + # Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended. + mode: "w4a16" + + # Quantization group size (NVFP4 requires 16) + group_size: 16 + + # Patterns to ignore (e.g., lm_head, embed_tokens) + ignore_patterns: + + - "lm_head" + - "embed_tokens" + - "re:.*mlp.gate$" + + # Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax" + activation_observer: "static_minmax" + + # Path to vLLM quantization config JSON file + quantization_config_path: null diff --git a/verl/trainer/config/actor/torchtitan_actor.yaml b/verl/trainer/config/actor/torchtitan_actor.yaml new file mode 100644 index 00000000000..3bf25b9a5a7 --- /dev/null +++ b/verl/trainer/config/actor/torchtitan_actor.yaml @@ -0,0 +1,16 @@ +# torchtitan actor config, inheriting from trainer/config/actor/actor.yaml +defaults: + # torchtitan optimizer config + - ../optim@optim: torchtitan + + # torchtitan engine config + - ../engine@torchtitan: torchtitan + + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +_target_: verl.workers.config.TorchTitanActorConfig + +strategy: torchtitan diff --git a/verl/trainer/config/critic/torchtitan_critic.yaml b/verl/trainer/config/critic/torchtitan_critic.yaml new file mode 100644 index 00000000000..4fafbd9d227 --- /dev/null +++ b/verl/trainer/config/critic/torchtitan_critic.yaml @@ -0,0 +1,28 @@ +# defaults specify the default config from each component +defaults: + + # torchtitan optimizer config + - ../optim@optim: torchtitan + + # torchtitan engine config + - ../engine@torchtitan: torchtitan + + # critic config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.TorchTitanCriticConfig + +strategy: torchtitan + +# model config for the critic +model: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.BaseModelConfig + +# seed for data loader +data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} diff --git a/verl/trainer/config/engine/torchtitan.yaml b/verl/trainer/config/engine/torchtitan.yaml new file mode 100644 index 00000000000..0f75004c0a4 --- /dev/null +++ b/verl/trainer/config/engine/torchtitan.yaml @@ -0,0 +1,74 @@ +# Target class for this configuration +_target_: verl.workers.config.TorchtitanEngineConfig + +# Whether to offload model parameters to CPU +param_offload: False + +# Whether to offload optimizer state to CPU +optimizer_offload: False + +# policy for wrapping the model +wrap_policy: + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + +# The policy for applying `reshard_after_forward` within an FSDP setup +# Options: "default", "always", "never" +reshard_after_forward: default + +# Prefetch the next forward-pass all-gather before the current forward computation. +forward_prefetch: false + +# Whether to use original parameters +use_orig_params: false + +# Mixed precision configuration for FSDP +mixed_precision: false + +# Whether to use torch compile +use_torch_compile: true + +# Whether to use entropy_from_logits_with_chunking +entropy_from_logits_with_chunking: false + +# Whether to use entropy checkpointing +entropy_checkpointing: false + +# Data parallel size (FSDP group size) +data_parallel_size: 1 + +# Data parallel replicate size +data_parallel_replicate_size: 1 + +# Data parallel shard size +data_parallel_shard_size: 1 + +# Tensor parallel size +tensor_parallel_size: 1 + +# Expert parallel size +expert_parallel_size: 1 + +# Pipeline parallel size +pipeline_parallel_size: 1 + +# Context parallel size +context_parallel_size: 1 + +# Attention type for torchtitan's model (e.g., "sdpa", "flex", "varlen") +attn_type: flex + +# Strategy +strategy: torchtitan + +# Random seed for reproducibility +seed: 42 + +# Whether to enable full determinism for distributed training, only for debugging +full_determinism: false + +# Whether to use forward only +forward_only: false + +# Mixed precision training param dtype +dtype: bfloat16 diff --git a/verl/trainer/config/engine/veomni.yaml b/verl/trainer/config/engine/veomni.yaml index 3bbbcf55e60..1869dd745f9 100644 --- a/verl/trainer/config/engine/veomni.yaml +++ b/verl/trainer/config/engine/veomni.yaml @@ -7,22 +7,13 @@ param_offload: False # Whether to offload optimizer state to CPU optimizer_offload: False -data_parallel_size: 1 +# FSDP group size. -1 means use all available GPUs. +fsdp_size: -1 -data_parallel_replicate_size: 1 - -data_parallel_shard_size: 1 - -tensor_parallel_size: 1 +ulysses_parallel_size: 1 expert_parallel_size: 1 -pipeline_parallel_size: 1 - -context_parallel_size: 1 - -ulysses_parallel_size: 1 - mixed_precision: true # Random seed for reproducibility. diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml deleted file mode 100644 index 478733339ce..00000000000 --- a/verl/trainer/config/generation.yaml +++ /dev/null @@ -1,62 +0,0 @@ -trainer: - nnodes: 1 - n_gpus_per_node: 8 - device: cuda - -data: - path: ~/data/rlhf/math/test.parquet - prompt_key: prompt - n_samples: 5 - output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet - batch_size: 128 - -model: - path: ~/models/Qwen2-7B-Instruct - external_lib: null -rollout: - _target_: verl.workers.config.RolloutConfig - name: vllm - # NOTE: 'sync' mode was removed in PR #4411. Only 'async' mode is supported. - # WARNING: The main_generation.py workflow is currently broken for vLLM async rollout - # as it requires synchronous generate_sequences() which vLLMAsyncRollout doesn't support. - # See issue #4682 for discussion and workarounds. - mode: async - temperature: 1.0 - top_k: 50 # 0 for hf rollout, -1 for vllm rollout - top_p: 0.7 - prompt_length: 1536 - response_length: 512 - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: auto - tensor_model_parallel_size: 1 - data_parallel_size: 1 - max_num_batched_tokens: 8192 - max_model_len: null - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 8 - # for hf rollout - do_sample: True - disable_log_stats: True - enable_chunked_prefill: True - n: 1 - # support logging rollout prob for debugging purpose - calculate_log_probs: False -actor: - strategy: fsdp # This is for backward-compatibility - ulysses_sequence_parallel_size: 1 # sp size - entropy_from_logits_with_chunking: False # calculate entropy with chunking to reduce memory peak - entropy_checkpointing: False # recompute entropy - fsdp_config: - fsdp_size: -1 - forward_prefetch: False # FSDP1 forward_prefetch configuration - -ray_kwargs: - ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. - timeline_json_file: null diff --git a/verl/trainer/config/legacy_reward_impl.yaml b/verl/trainer/config/legacy_reward_impl.yaml index 476a6e2810f..e1266ba9a89 100644 --- a/verl/trainer/config/legacy_reward_impl.yaml +++ b/verl/trainer/config/legacy_reward_impl.yaml @@ -12,6 +12,10 @@ reward_model: reward_loop_source: null reward_loop_module_path: null reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null rollout: name: null dtype: null diff --git a/verl/trainer/config/model_engine/torchtitan.yaml b/verl/trainer/config/model_engine/torchtitan.yaml new file mode 100644 index 00000000000..32a4b3d0611 --- /dev/null +++ b/verl/trainer/config/model_engine/torchtitan.yaml @@ -0,0 +1,2 @@ +# @package _global_ +model_engine: torchtitan diff --git a/verl/trainer/config/optim/fsdp.yaml b/verl/trainer/config/optim/fsdp.yaml index a7dd99b1ee2..ce6ced773b6 100644 --- a/verl/trainer/config/optim/fsdp.yaml +++ b/verl/trainer/config/optim/fsdp.yaml @@ -38,6 +38,9 @@ num_cycles: 0.5 # LR scheduler type: "constant" or "cosine" lr_scheduler_type: constant +# Whether the LR schedule uses 0-indexed steps +zero_indexed_step: true + # deprecated warmup_style: null diff --git a/verl/trainer/config/optim/torchtitan.yaml b/verl/trainer/config/optim/torchtitan.yaml new file mode 100644 index 00000000000..baea31ee527 --- /dev/null +++ b/verl/trainer/config/optim/torchtitan.yaml @@ -0,0 +1,35 @@ +# Target class for this configuration +_target_: verl.workers.config.TorchtitanOptimizerConfig + +# Optimizer name +name: AdamW + +# Learning rate +lr: 1e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# Epsilon for Adam optimizer +eps: 1e-8 + +# Decay type: "linear", "sqrt", or "cosine" +decay_type: linear + +# Minimum LR factor for cosine schedule +min_lr_factor: 0.0 diff --git a/verl/trainer/config/ref/torchtitan_ref.yaml b/verl/trainer/config/ref/torchtitan_ref.yaml new file mode 100644 index 00000000000..80cc01b1098 --- /dev/null +++ b/verl/trainer/config/ref/torchtitan_ref.yaml @@ -0,0 +1,28 @@ +# torchtitan ref config, inheriting from trainer/config/ref/ref.yaml +defaults: + # torchtitan optimizer config + - ../optim@optim: torchtitan + + # torchtitan engine config + - ../engine@torchtitan: torchtitan + + - ref + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +_target_: verl.workers.config.TorchTitanActorConfig + +strategy: torchtitan + +torchtitan: + seed: ${oc.select:actor_rollout_ref.actor.torchtitan.seed,42} + data_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_size,1} + data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_replicate_size,1} + data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.torchtitan.data_parallel_shard_size,1} + tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.tensor_parallel_size,1} + expert_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.expert_parallel_size,1} + pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.pipeline_parallel_size,1} + context_parallel_size: ${oc.select:actor_rollout_ref.actor.torchtitan.context_parallel_size,1} + attn_type: ${oc.select:actor_rollout_ref.actor.torchtitan.attn_type,flex} + forward_only: True diff --git a/verl/trainer/config/ref/veomni_ref.yaml b/verl/trainer/config/ref/veomni_ref.yaml index f52421fd39f..738962356fe 100644 --- a/verl/trainer/config/ref/veomni_ref.yaml +++ b/verl/trainer/config/ref/veomni_ref.yaml @@ -14,14 +14,9 @@ strategy: veomni veomni: seed: ${oc.select:actor_rollout_ref.actor.veomni.seed,42} - data_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_size,1} - data_parallel_replicate_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_replicate_size,1} - data_parallel_shard_size: ${oc.select:actor_rollout_ref.actor.veomni.data_parallel_shard_size,1} - tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.tensor_parallel_size,1} - expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1} - pipeline_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.pipeline_parallel_size,1} - context_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.context_parallel_size,1} + fsdp_size: ${oc.select:actor_rollout_ref.actor.veomni.fsdp_size,-1} ulysses_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.ulysses_parallel_size,1} + expert_parallel_size: ${oc.select:actor_rollout_ref.actor.veomni.expert_parallel_size,1} param_offload: ${oc.select:actor_rollout_ref.actor.veomni.param_offload,False} attn_implementation: ${oc.select:actor_rollout_ref.actor.veomni.attn_implementation,flash_attention_2} moe_implementation: ${oc.select:actor_rollout_ref.actor.veomni.moe_implementation,fused} diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index 575c27551fa..894538d1d87 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -7,6 +7,12 @@ name: ??? # sync: LLM, async: AsyncLLM mode: async +# Number of nodes for standalone rollout server, must be > 0 in one-step-off/fully async training. +nnodes: 0 + +# Number of GPUs per node for rollout server. +n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} + # Sampling temperature for rollout. temperature: 1.0 @@ -273,6 +279,12 @@ trace: # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.workers.config.TraceConfig + # Project name for experiment tracking (e.g., wandb) + project_name: ${oc.select:trainer.project_name,null} + + # Experiment name for run identification in tracking tools + experiment_name: ${oc.select:trainer.experiment_name,null} + # trace backend, support mlflow, weave backend: null @@ -386,3 +398,6 @@ quantization_config_file: null # MTP configuration, reuse model configuration mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + +# QAT configuration (inherited from actor.qat) +qat: ${oc.select:actor_rollout_ref.actor.qat,null} diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml deleted file mode 100644 index b2308e39e44..00000000000 --- a/verl/trainer/config/sft_trainer.yaml +++ /dev/null @@ -1,91 +0,0 @@ -defaults: - - optim: fsdp - - _self_ - -data: - train_batch_size: 256 - micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: 4 # this is also val batch size - train_files: ~/data/gsm8k/train.parquet - val_files: ~/data/gsm8k/test.parquet - train_max_samples: -1 # set to -1 to use full dataset - val_max_samples: -1 # set to -1 to use full dataset - # Single-turn settings - prompt_key: question - response_key: answer - prompt_dict_keys: null - response_dict_keys: null - # Multi-turn settings - multiturn: - enable: false # Set to true to use multi-turn dataset - messages_key: messages # Key for messages list in multi-turn mode - tools_key: tools # Key for tools list in multi-turn mode - enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode - max_length: 1024 - truncation: error - balance_dp_token: False - chat_template: null - custom_cls: - path: null - name: null - use_shm: False - apply_chat_template_kwargs: {} -model: - partial_pretrain: ~/models/gemma-1.1-7b-it - use_shm: False - fsdp_config: - model_dtype: fp32 - wrap_policy: - min_num_params: 0 - cpu_offload: False - offload_params: False - external_lib: null - enable_gradient_checkpointing: True - trust_remote_code: False - lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) - lora_alpha: 16 # LoRA scaling factor - target_modules: all-linear # Target modules for LoRA adaptation - use_liger: False - strategy: fsdp2 -optim: - lr: 1e-5 - betas: [0.9, 0.95] - weight_decay: 0.01 - lr_warmup_steps_ratio: 0.1 - clip_grad: 1.0 - lr_scheduler: cosine -ulysses_sequence_parallel_size: 1 -use_remove_padding: False -trainer: - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} - default_hdfs_dir: null - project_name: gsm8k-sft - experiment_name: test - total_epochs: 4 - total_training_steps: null - logger: [ 'console', 'wandb' ] - seed: 1 - save_freq: -1 - test_freq: -1 - nnodes: 1 - n_gpus_per_node: 8 - max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all - - # Resume mode: "auto", "disable", or "resume_path" - # "auto": resume from last checkpoint if available - # "disable": start from scratch - # "resume_path": resume from a user-defined path - resume_mode: auto - - # Path to resume training from (used when resume_mode is "resume_path" or "auto") - resume_from_path: null - - # Checkpoint configuration - checkpoint: - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ["model", "optimizer", "extra"] - - # For more flexibility, you can specify the contents to load from the checkpoint. - load_contents: ${trainer.checkpoint.save_contents} - device: cuda diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py deleted file mode 100644 index e8cc00e16c0..00000000000 --- a/verl/trainer/fsdp_sft_trainer.py +++ /dev/null @@ -1,873 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -A lightweight one-file FSDP SFT Trainer -TODO(zhangchi.usc1992) -- Add calculation of mfu -- Add validation -""" - -import os - -os.environ["NCCL_DEBUG"] = "WARN" -os.environ["TOKENIZERS_PARALLELISM"] = "true" - -import logging -import re -import time -from contextlib import nullcontext - -import hydra -import torch -import torch.distributed -from omegaconf import DictConfig, OmegaConf -from peft import LoraConfig, TaskType, get_peft_model -from tensordict import TensorDict -from torch import nn -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.utils.data import Dataset, DistributedSampler -from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm import tqdm -from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel - -import verl.utils.hdfs_io as hdfs_io -from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.dataset import SFTDataset -from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset -from verl.utils.device import ( - auto_set_device, - get_device_id, - get_device_name, - is_cuda_available, - is_npu_available, -) -from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group -from verl.utils.fs import copy_to_local -from verl.utils.fsdp_utils import ( - CPUOffloadPolicy, - MixedPrecisionPolicy, - apply_fsdp2, - fsdp2_clip_grad_norm_, - fsdp2_load_full_state_dict, - get_fsdp_wrap_policy, - get_init_weight_context_manager, - init_fn, -) -from verl.utils.logger import log_with_rank -from verl.utils.profiler import log_gpu_memory_usage -from verl.utils.py_functional import convert_to_regular_types -from verl.utils.torch_dtypes import PrecisionType -from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup -from verl.utils.tracking import Tracking -from verl.utils.ulysses import ( - gather_outputs_and_unpad, - get_ulysses_sequence_parallel_world_size, - ulysses_pad_and_slice_inputs, -) -from verl.workers.config.optimizer import build_optimizer -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) - - -def extract_step(path): - match = re.search(r"global_step_(\d+)", path) - if match: - return int(match.group(1)) - return None - - -class FSDPSFTTrainer: - def __init__( - self, - config, - device_mesh: DeviceMesh, - ulysses_device_mesh: DeviceMesh, - tokenizer, - train_dataset: Dataset, - val_dataset: Dataset, - ): - self.config = config - self.device_mesh = device_mesh - self.ulysses_device_mesh = ulysses_device_mesh - self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - self.tokenizer = tokenizer - if self.config.data.chat_template is not None: - raise ValueError("Apply Chat template from config is not supported yet.") - - # normalize dp size - self._normalize_config_bsz() - - # Set sequence parallel size - self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) - self.use_remove_padding = getattr(self.config, "use_remove_padding", False) - if self.device_mesh.get_rank() == 0: - print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") - print(f"Using remove padding: {self.use_remove_padding}") - - self._build_dataloader(train_dataset, val_dataset) - - self.lora = self.config.model.get("lora_adapter_path") is not None or self.config.model.lora_rank > 0 - - # Initialize resume-related variables - self.resume_global_step = 0 - - # build model - self._build_model_optimizer() - - # Initialize checkpoint manager - self._init_checkpoint_manager() - - self.load_checkpoint() - - if self.device_mesh.get_rank() == 0: - print(self.config) - - self.device_name = self.config.trainer.device - - def _normalize_config_bsz(self): - dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) - if self.device_mesh.get_rank() == 0: - print(f"Normalize batch size by dp {dp_size}") - - assert self.config.data.train_batch_size % dp_size == 0, ( - f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" - ) - - self.config.data.train_batch_size //= dp_size - - assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 - - def _build_dataloader(self, train_dataset, val_dataset): - # build dataset - config = self.config - self.train_dataset, self.val_dataset = train_dataset, val_dataset - - # build dataloader - # Use data parallel rank and size instead of global rank and world size - - # If doing SP, we need to use the local rank and size - if self.config.ulysses_sequence_parallel_size > 1: - rank = self.ulysses_device_mesh.get_local_rank("dp") - world_size = self.ulysses_device_mesh.size(0) - if self.ulysses_device_mesh.get_rank() == 0: - print(f"Using SP rank {rank} and size {world_size} for data distribution") - print("Each SP rank gets different data, but the same data WITHIN the same rank") - else: - rank = self.device_mesh.get_rank() - world_size = self.device_mesh.size() - if self.device_mesh.get_rank() == 0: - print(f"Using FSDP rank {rank} and size {world_size} for data distribution") - - # Set pin_memory_device when pin_memory is enabled. - device_name = get_device_name() - - self.train_sampler = DistributedSampler( - self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True - ) - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=config.data.train_batch_size, - sampler=self.train_sampler, - num_workers=8, - pin_memory=True, - drop_last=True, - pin_memory_device=device_name, - ) - - self.val_sampler = DistributedSampler( - self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True - ) - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - batch_size=config.data.micro_batch_size_per_gpu, - sampler=self.val_sampler, - num_workers=8, - pin_memory=True, - drop_last=True, - pin_memory_device=device_name, - ) - - def _build_model_optimizer(self): - # TODO (zhangchi.usc1992): - # 1. support pretrain from random weights - # 2. support init directly from sharded weights - local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) - - if self.config.model.get("external_lib", None) is not None: - # This is used to import external_lib into the huggingface systems - import importlib - - importlib.import_module(self.config.model.external_lib) - - log_gpu_memory_usage("Before model allocation", logger=logger) - - trust_remote_code = self.config.model.trust_remote_code - torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") - torch_dtype = PrecisionType.to_dtype(torch_dtype) - # load config first - config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) - self.model_config = config - if hasattr(self.model_config, "max_position_embeddings"): - self.model_config.max_position_embeddings = max( - self.model_config.max_position_embeddings, self.config.data.max_length - ) - if self.config.ulysses_sequence_parallel_size > 1: - assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" - - # This may be very large - init_context = get_init_weight_context_manager( - use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(): - self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - local_model_path, - config=config, - torch_dtype=torch_dtype, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) - - # Apply Liger kernel if use_liger is enabled - if self.config.model.get("use_liger", False): - from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance - - _apply_liger_kernel_to_instance(model=self.model) - - if self.lora: - self.model.enable_input_require_grads() - - lora_adapter_path = self.config.model.get("lora_adapter_path") - if lora_adapter_path is not None: - from peft import PeftModel - - print(f"Loading pre-trained LoRA adapter for sft from: {lora_adapter_path}") - - local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.use_shm) - - self.model = PeftModel.from_pretrained(self.model, local_adapter_path, is_trainable=True) - peft_config = self.model.peft_config["default"] - # Ensure task_type is TaskType enum, not string - if isinstance(peft_config.task_type, str): - peft_config.task_type = TaskType.CAUSAL_LM - else: - # Convert config to regular Python types before creating PEFT model - lora_config = { - "task_type": TaskType.CAUSAL_LM, - "r": self.config.model.lora_rank, - "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), - "bias": "none", - } - self.model = get_peft_model(self.model, LoraConfig(**lora_config)) - self.model = self.model.to(torch_dtype) - - if self.config.model.enable_gradient_checkpointing: - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - - log_gpu_memory_usage("After model allocation", logger=logger) - - mixed_precision = MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 - ) - - auto_wrap_policy = get_fsdp_wrap_policy( - self.model, - config=self.config.model.fsdp_config.wrap_policy, - is_lora=self.lora, - ) - - if self.device_mesh.get_rank() == 0: - print(auto_wrap_policy) - - if not self.config.model.fsdp_config.cpu_offload: - cpu_offload = None - else: - cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) - - fsdp_strategy = self.config.model.strategy - if fsdp_strategy == "fsdp": - self.fsdp_model = FSDP( - self.model, - cpu_offload=cpu_offload, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - sync_module_states=True, - device_mesh=self.device_mesh, - forward_prefetch=False, - ) - elif fsdp_strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True - ) - - fsdp_kwargs = { - "mesh": self.device_mesh, - "mp_policy": mp_policy, - "offload_policy": cpu_offload, - "reshard_after_forward": True, - } - full_state = self.model.state_dict() - apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config) - fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload) - self.fsdp_model = self.model - else: - raise NotImplementedError(f"not implement {fsdp_strategy}") - - log_gpu_memory_usage("After FSDP wrapping", logger=logger) - - self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim) - - log_gpu_memory_usage("After initialize optimizer", logger=logger) - - self.steps_per_epoch = len(self.train_dataloader) - self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs - - if self.device_mesh.get_rank() == 0: - print( - f"Number of steps/epoch {self.steps_per_epoch}, number of epochs " - f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}" - ) - - num_warmup_steps = int(self.total_steps * self.config.optim.lr_warmup_steps_ratio) - - if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": - self.lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps - ) - elif self.config.optim.lr_scheduler == "wsd": - self.lr_scheduler = get_wsd_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps - ) - else: - raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") - - def _compute_loss_and_backward(self, batch, do_backward=True, n_micro_batches=1): - """Compute loss with optional sequence parallelism and remove padding features""" - use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 - - # Move inputs to GPU and prepare loss mask - input_ids = batch["input_ids"].to(self.device_name) - attention_mask = batch["attention_mask"].to(self.device_name) - position_ids = batch["position_ids"].to(self.device_name) - loss_mask = batch.pop("loss_mask")[:, 1:].reshape(-1).to(self.device_name) - loss_fct = nn.CrossEntropyLoss(reduction="none") - - # Context manager for sequence parallel if needed - context = self.sharding_manager if use_sp else nullcontext() - with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): - if not use_sp: - # Standard forward pass without sequence parallel - labels = input_ids[:, 1:].contiguous() - output = self.fsdp_model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False - ) - logits = output.logits - - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels.contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.model.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - loss = loss * loss_mask.to(loss.device) - else: - # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks - # i.e., each GPU has <1 sequence, and each SP group has 1 sequence - # 1. All SP ranks will receive the *SAME* batch - # 2. Different SP groups will receive *DIFFERENT* batches - # This is implemented by the DistributedSampler - - batch_size, seqlen = input_ids.shape - # Remove padding - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # Unpad position_ids to align rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # Pad and slice inputs for sequence parallelism - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() - ) - # For computing loss - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() - ) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) - - # Forward pass - output = self.fsdp_model( - input_ids=input_ids_rmpad_sliced, - attention_mask=None, # Not needed with flash attention varlen - position_ids=position_ids_rmpad_padded, - use_cache=False, - ) - - # Compute loss locally then aggregate - logits_rmpad = output.logits.squeeze(0) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) - loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) - # Gather and unpad for sequence parallelism - loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) - - # This is the loss collected from all ulysses ranks - full_loss = pad_input( - hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ) - full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss - full_loss = full_loss.reshape(-1) - loss_mask = loss_mask.to(full_loss.device) - loss = full_loss * loss_mask - - valid_token_this_rank = torch.sum(loss_mask) - - if self.config.data.balance_dp_token: - torch.distributed.all_reduce(valid_token_this_rank) - dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() - else: - dp_size = 1 - - loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size - - loss = loss / n_micro_batches # normalize loss - - if do_backward: - loss.backward() - return loss - - def training_step(self, batch: TensorDict): - start_time = time.time() - - self.fsdp_model.train() - - log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) - - self.optimizer.zero_grad() - - log_gpu_memory_usage("After optimizer zero_grad", logger=logger) - - micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) - n_micro_batches = len(micro_batches) - step_loss = 0 - for micro_batch in micro_batches: - loss = self._compute_loss_and_backward(batch=micro_batch, n_micro_batches=n_micro_batches) - step_loss += loss.item() - - if self.config.model.strategy == "fsdp": - grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) - elif self.config.model.strategy == "fsdp2": - grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad) - else: - raise NotImplementedError(f"not implement {self.config.model.strategy}") - - log_gpu_memory_usage("Before optimizer step", logger=logger) - - # if grad_norm is not finite, skip the update - if not torch.isfinite(grad_norm): - print(f"WARN: grad_norm is not finite: {grad_norm}") - self.optimizer.zero_grad() - else: - self.optimizer.step() - - log_gpu_memory_usage("After optimizer step", logger=logger) - - self.lr_scheduler.step() - - # reduce loss across dp ranks - lr = self.lr_scheduler.get_last_lr()[0] - - log_gpu_memory_usage("After offload weights", logger=logger) - - step_loss = torch.tensor(step_loss).to(self.device_name) - - # compute time spent per step - end_time = time.time() - spend_time_per_step = end_time - start_time - - if is_cuda_available: - torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) - elif is_npu_available: - torch.distributed.all_reduce(step_loss) - step_loss /= self.device_mesh.size(0) - return { - "train/loss": step_loss.detach().item(), - "train/lr(1e-3)": lr * 1e3, - "train/time(s)": spend_time_per_step, - } - - def validation_step(self, batch: TensorDict): - self.fsdp_model.eval() - with torch.no_grad(): - loss = self._compute_loss_and_backward(batch, do_backward=False) - if is_cuda_available: - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) - elif is_npu_available: - torch.distributed.all_reduce(loss) - loss /= self.device_mesh.size(0) - return loss - - def save_checkpoint(self, step): - """Save checkpoint using FSDPCheckpointManager with improved tracking""" - from verl.utils.fs import local_mkdir_safe - - # Determine checkpoint path - local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") - - if self.device_mesh.get_rank() == 0: - print(f"Saving checkpoint to: {local_global_step_folder}") - - # Get max checkpoints to keep - max_ckpt_to_keep = getattr(self.config.trainer, "max_ckpt_to_keep", None) - - # Use checkpoint manager to save - self.checkpoint_manager.save_checkpoint( - local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep - ) - - # Save dataloader state - if self.device_mesh.get_rank() == 0: - local_mkdir_safe(local_global_step_folder) - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - - # Use StatefulDataLoader's built-in state dict functionality - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - print(f"Saved dataloader state to: {dataloader_local_path}") - - # Update latest checkpoint tracker (atomic write) - tracker_file = get_checkpoint_tracker_filename(self.config.trainer.default_local_dir) - temp_tracker_file = tracker_file + ".tmp" - with open(temp_tracker_file, "w") as f: - f.write(str(step)) - os.rename(temp_tracker_file, tracker_file) - print(f"Updated checkpoint tracker: {tracker_file}") - - # Copy to HDFS if configured - if self.device_mesh.get_rank() == 0 and getattr(self.config.trainer, "default_hdfs_dir", None): - hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) - hdfs_io.copy(src=local_global_step_folder, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) - - torch.distributed.barrier() - - def _init_checkpoint_manager(self): - """Initialize checkpoint manager with proper configuration""" - # Get checkpoint configuration from config, with defaults - checkpoint_config = getattr(self.config.trainer, "checkpoint", {}) - - # Set default values if not specified - save_contents = checkpoint_config.get("save_contents", ["model", "optimizer", "extra"]) - load_contents = checkpoint_config.get("load_contents", save_contents) - - # Create checkpoint config dict - checkpoint_config_dict = { - "load_contents": load_contents, - "save_contents": save_contents, - } - - # Convert to DictConfig for compatibility - checkpoint_config_dict = DictConfig(checkpoint_config_dict) - - # Initialize checkpoint manager - self.checkpoint_manager = FSDPCheckpointManager( - model=self.fsdp_model, - optimizer=self.optimizer, - lr_scheduler=self.lr_scheduler, - processing_class=self.tokenizer, - checkpoint_config=checkpoint_config_dict, - trust_remote_code=self.config.model.trust_remote_code, - ) - - def load_checkpoint(self): - # Determine resume path based on configuration - checkpoint_path = self._determine_resume_path() - - if checkpoint_path is None: - return 0 - - # extract resume step from checkpoint path - resume_step = extract_step(checkpoint_path) - if resume_step is None: - log_with_rank( - f"Warning: Could not extract step number from {checkpoint_path}, starting from step 0", - logger=logger, - rank=self.device_mesh.get_rank(), - level=logging.WARNING, - log_only_rank_0=True, - ) - return 0 - self.resume_global_step = resume_step - - # Use checkpoint manager to load model state - self.checkpoint_manager.load_checkpoint(checkpoint_path) - log_with_rank( - f"Successfully loaded model checkpoint from {checkpoint_path} (step {resume_step})", - logger=logger, - rank=self.device_mesh.get_rank(), - log_only_rank_0=True, - ) - - # Always load dataloader state for StatefulDataLoader - self._load_dataloader_state(checkpoint_path) - - return resume_step - - def _load_dataloader_state(self, checkpoint_path: str): - """Load dataloader state from checkpoint""" - dataloader_path = os.path.join(checkpoint_path, "data.pt") - - if os.path.exists(dataloader_path): - # Use StatefulDataLoader's built-in state dict functionality - dataloader_state_dict = torch.load(dataloader_path, map_location="cpu", weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - - log_with_rank( - f"Successfully loaded dataloader state from {dataloader_path}", - logger=logger, - rank=self.device_mesh.get_rank(), - log_only_rank_0=True, - ) - - else: - log_with_rank( - f"Warning: No dataloader state found at {dataloader_path}, will start from scratch", - logger=logger, - rank=self.device_mesh.get_rank(), - level=logging.WARNING, - log_only_rank_0=True, - ) - - def _determine_resume_path(self): - """Determine the path to resume from based on resume_mode configuration""" - resume_mode = getattr(self.config.trainer, "resume_mode", "auto") - resume_from_path = getattr(self.config.trainer, "resume_from_path", None) - - if resume_mode == "disable": - return None - elif resume_mode == "auto": - if resume_from_path is not None: - assert os.path.exists(resume_from_path), ( - "resume_from_path must be null or an existing path when resume_mode is 'auto'" - ) - assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps" - return resume_from_path - # Try to find the latest checkpoint in the default directory - return self._find_latest_checkpoint() - elif resume_mode == "resume_path": - assert os.path.exists(resume_from_path), ( - "resume_from_path must be an existing path when resume_mode is 'resume_path'" - ) - assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps" - return resume_from_path - else: - raise ValueError(f"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'") - - def _find_latest_checkpoint(self): - """Find the latest checkpoint in the default local directory""" - checkpoint_dir = self.config.trainer.default_local_dir - - if not os.path.exists(checkpoint_dir): - return None - - latest_checkpoint = find_latest_ckpt_path(checkpoint_dir) - - if latest_checkpoint and self.device_mesh.get_rank() == 0: - step_num = extract_step(latest_checkpoint) - print(f"Found latest checkpoint: {latest_checkpoint} (step {step_num})") - - return latest_checkpoint - - def fit(self): - rank = self.device_mesh.get_rank() - - # TODO: add a unified tracking - if rank == 0: - tracking = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - global_step = self.resume_global_step # Start from resumed step - last_valid_metric = None - # compute the total training steps. - # the total training steps in SFT is mainly for early exit - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - log_with_rank( - f"Total training steps: {self.total_training_steps},", - logger=logger, - rank=self.device_mesh.get_rank(), - log_only_rank_0=True, - ) - - # With StatefulDataLoader, we don't need to manually calculate epochs and steps - # The dataloader will automatically resume from where it left off - if global_step > 0: - log_with_rank( - f"StatefulDataLoader will automatically resume from global step: {global_step}", - logger=logger, - rank=self.device_mesh.get_rank(), - log_only_rank_0=True, - ) - - # Calculate which epoch we're starting from for sampler.set_epoch() - start_epoch = global_step // self.steps_per_epoch - - train_time = 0 - for epoch in range(start_epoch, self.config.trainer.total_epochs): - self.train_sampler.set_epoch(epoch=epoch) - - for step_in_epoch, data in enumerate( - tqdm( - self.train_dataloader, - initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0, - total=self.steps_per_epoch, - desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", - disable=rank != 0, - ) - ): - global_step += 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) - metric = self.training_step(data) - train_time += metric["train/time(s)"] - if rank == 0: - tracking.log(data=metric, step=global_step) - - is_last_step = global_step >= self.total_training_steps - is_valid_step = global_step % self.config.trainer.test_freq == 0 - is_save_step = global_step % self.config.trainer.save_freq == 0 - - # early exit or validation step - if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step): - # Perform validation - val_losses = [] - for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to( - self.device_name - ) - val_loss = self.validation_step(val_data) - val_losses.append(val_loss) - if rank == 0: - val_loss = torch.mean(torch.stack(val_losses)) - metric = {"val/loss": val_loss.detach().item()} - tracking.log(data=metric, step=global_step) - last_valid_metric = metric - torch.distributed.barrier() - - if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step): - self.save_checkpoint(step=global_step) - - if is_last_step: - if rank == 0: - print(f"Total time for train steps: {train_time:.2f}s") - print(f"Final validation metrics: {last_valid_metric}") - return - - -def run_sft(config): - device_name = get_device_name() - local_rank, rank, world_size = initialize_global_process_group() - - device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) - dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh( - device_type=device_name, - mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), - mesh_dim_names=("dp", "sp"), - ) - # build tokenizer and datasets first - from verl.utils import hf_tokenizer - - local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) - tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) - train_dataset = create_sft_dataset( - config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1) - ) - val_dataset = create_sft_dataset( - config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1) - ) - - trainer = FSDPSFTTrainer( - config=config, - device_mesh=device_mesh, - ulysses_device_mesh=ulysses_device_mesh, - tokenizer=tokenizer, - train_dataset=train_dataset, - val_dataset=val_dataset, - ) - - trainer.fit() - - destroy_global_process_group() - - -@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) -def main(config): - # Automatically set `config.trainer.device = npu` when running on Ascend NPU. - auto_set_device(config) - - run_sft(config) - - -def create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1): - """Create a dataset.""" - # build dataset - # First check if a custom dataset class is specified - if data_config.custom_cls.get("path", None): - from verl.utils.import_utils import load_extern_object - - dataset_cls = load_extern_object(data_config.custom_cls.path, data_config.custom_cls.name) - # Then check if multi-turn dataset should be used - elif data_config.get("multiturn", {}).get("enable", False): - dataset_cls = MultiTurnSFTDataset - # Default to single-turn dataset - else: - dataset_cls = SFTDataset - - # Create datasets based on the selected class - dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples) - return dataset - - -if __name__ == "__main__": - main() diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py deleted file mode 100644 index 18aaa8cdbd0..00000000000 --- a/verl/trainer/main_generation.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Generate responses given a dataset of prompts -""" - -import os - -import hydra -import numpy as np -import ray - -os.environ["NCCL_DEBUG"] = "WARN" -os.environ["TOKENIZERS_PARALLELISM"] = "true" -# os.environ['TORCH_COMPILE_DISABLE'] = '1' - -from pprint import pprint - -import pandas as pd -from omegaconf import OmegaConf - -from verl import DataProto -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_to_local -from verl.utils.hdfs_io import makedirs -from verl.utils.model import compute_position_id_with_mask -from verl.workers.fsdp_workers import ActorRolloutRefWorker - - -@hydra.main(config_path="config", config_name="generation", version_base=None) -def main(config): - run_generation(config) - - -def run_generation(config) -> None: - if not ray.is_initialized(): - # this is for local ray cluster - default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}} - ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) - runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) - runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) - ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) - print(f"ray init kwargs: {ray_init_kwargs}") - ray.init(**OmegaConf.to_container(ray_init_kwargs)) - - ray.get(main_task.remote(config)) - - -@ray.remote(num_cpus=1) -def main_task(config): - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - local_path = copy_to_local(config.model.path) - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - - if config.rollout.temperature == 0.0: - assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1." - assert config.data.n_samples >= 1, "n_samples should always >= 1" - - # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) - dataset = pd.read_parquet(config.data.path) - chat_lst = dataset[config.data.prompt_key].tolist() - - chat_lst = [chat.tolist() for chat in chat_lst] - - tokenizer.padding_side = "left" - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") - resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - - wg = RayWorkerGroup( - resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - device_name=config.trainer.device, - ) - wg.init_model() - - total_samples = len(dataset) - config_batch_size = config.data.batch_size - apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {}) - num_batch = -(-total_samples // config_batch_size) - output_lst = [[] for _ in range(config.data.n_samples)] - - for batch_idx in range(num_batch): - print(f"[{batch_idx + 1}/{num_batch}] Start to process.") - batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] - inputs = tokenizer.apply_chat_template( - batch_chat_lst, - add_generation_prompt=True, - padding=True, - truncation=True, - max_length=config.rollout.prompt_length, - return_tensors="pt", - return_dict=True, - tokenize=True, - **apply_chat_template_kwargs, - ) - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - position_ids = compute_position_id_with_mask(attention_mask) - batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} - - data = DataProto.from_dict(batch_dict) - data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) - - # START TO GENERATE FOR n_samples TIMES - print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") - for n_sample in range(config.data.n_samples): - output_padded = wg.generate_sequences(data_padded) - output = unpad_dataproto(output_padded, pad_size=pad_size) - - output_texts = [] - for i in range(len(output)): - data_item = output[i] - prompt_length = data_item.batch["prompts"].shape[-1] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() - valid_response_ids = data_item.batch["responses"][:valid_response_length] - response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) - output_texts.append(response_str) - - output_lst[n_sample].extend(output_texts) - - # convert output_lst from (n_samples, n_data) to (n_data, n_sampels) - output_lst = np.array(output_lst, dtype=object) - output_lst = np.transpose(output_lst, axes=(1, 0)).tolist() - - # add to the data frame - dataset["responses"] = output_lst - - # write to a new parquet - output_dir = os.path.dirname(config.data.output_path) - makedirs(output_dir, exist_ok=True) - dataset.to_parquet(config.data.output_path) - - -if __name__ == "__main__": - main() diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index b82414a2880..2c84374d245 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -162,8 +162,13 @@ def add_actor_rollout_worker(self, config): actor_rollout_cls = AsyncActorRolloutRefWorker ray_worker_group_cls = RayWorkerGroup - elif config.actor_rollout_ref.actor.strategy == "veomni": - raise NotImplementedError("VeOmni does not support legacy worker implementation") + elif ( + config.actor_rollout_ref.actor.strategy == "veomni" + or config.actor_rollout_ref.actor.strategy == "torchtitan" + ): + raise NotImplementedError( + f"{config.actor_rollout_ref.actor.strategy} does not support legacy worker implementation" + ) else: raise NotImplementedError @@ -191,14 +196,16 @@ def add_critic_worker(self, config): # TODO: switch this to TrainingWorker as well from verl.workers.megatron_workers import CriticWorker - elif config.critic.strategy == "veomni": + elif config.critic.strategy == "veomni" or config.critic.strategy == "torchtitan": if use_legacy_worker_impl == "disable": from verl.workers.engine_workers import TrainingWorker CriticWorker = TrainingWorker - print("Using new worker implementation") + print(f"Using new worker implementation for {config.critic.strategy}") else: - raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + raise ValueError( + f"Invalid use_legacy_worker_impl for {config.critic.strategy}: {use_legacy_worker_impl}" + ) else: raise NotImplementedError diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 2039fe56f62..b222dc1b705 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -763,7 +763,7 @@ def compute_optimal_token_baseline_advantage( old_log_probs: torch.Tensor, sum_pi_squared: torch.Tensor, rollout_is_weights: torch.Tensor = None, - handle_zero_tail: bool = False, + handle_zero_tail: bool = True, epsilon: float = 1e-8, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -791,7 +791,7 @@ def compute_optimal_token_baseline_advantage( None if not using IS handle_zero_tail: If True, zero baselines will be set in the portion of the longest trajectory that extends beyond the second-longest trajectory in the prompt group. - Default: False + Default: True epsilon: Small constant for numerical stability (default: 1e-8) Returns: @@ -1054,26 +1054,32 @@ def agg_loss( """ if loss_agg_mode == "token-mean": if batch_num_tokens is None: + if dp_size > 1: + raise ValueError("(global) batch_num_tokens is required when dp_size > 1") batch_num_tokens = loss_mask.sum() loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size - elif loss_agg_mode == "seq-mean-token-sum": + elif loss_agg_mode in ["seq-mean-token-sum", "seq-mean-token-sum-norm"]: seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences if global_batch_size is None: + if dp_size > 1: + raise ValueError("global_batch_size is required when dp_size > 1") global_batch_size = seq_mask.sum() loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean + if loss_agg_mode == "seq-mean-token-sum-norm": + if loss_scale_factor is None: + horizon = loss_mask.shape[-1] + loss_scale_factor = horizon + loss /= loss_scale_factor elif loss_agg_mode == "seq-mean-token-mean": seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean seq_mask = (seq_mask > 0).float() # exclude fully masked sequences if global_batch_size is None: + if dp_size > 1: + raise ValueError("global_batch_size is required when dp_size > 1") global_batch_size = seq_mask.sum() loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean - elif loss_agg_mode == "seq-mean-token-sum-norm": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) - if loss_scale_factor is None: - loss_scale_factor = loss_mask.shape[-1] - loss = torch.sum(seq_losses) / loss_scale_factor else: raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") @@ -1250,6 +1256,172 @@ def compute_policy_loss_vanilla( return pg_loss, pg_metrics +@register_policy_loss("dppo_tv") +def compute_policy_loss_dppo_tv( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for DPPO-Binary-TV. + + See https://arxiv.org/pdf/2602.04879 for more details. + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + config: `(verl.trainer.config.ActorConfig)`: + config for the actor. + rollout_log_probs: `(torch.Tensor)`: + log probabilities of actions under the rollout policy, shape (batch_size, response_length). + """ + + assert config is not None + assert not isinstance(config, AlgoConfig) + # Note: the clip_ratio is different from the standard PPO, it is the TV divergence threshold for DPPO. + clip_divergence = config.clip_ratio + clip_divergence_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_divergence + clip_divergence_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_divergence + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # Instead of dual-clip PPO, we use truncated importance sampling (TIS) to clip the policy loss. + # However, a large threshold is recommended to avoid performance degradation due to the truncation bias. + # See Section 5.4 in https://arxiv.org/pdf/2602.04879 for more details. + clip_ratio_c = config.get("clip_ratio_c", 20.0) + truncated_ratio = torch.clamp(ratio, max=clip_ratio_c) + truncated_ratio = truncated_ratio.detach() + + # Compute valid mask for DPPO-Binary-TV + prob = torch.exp(log_prob) + old_prob = torch.exp(old_log_prob) + valid_positive_mask = (prob - old_prob) <= clip_divergence_high + valid_negative_mask = (prob - old_prob) >= -clip_divergence_low + valid_mask = torch.where(advantages > 0, valid_positive_mask, valid_negative_mask) + valid_mask = valid_mask.detach().float() + + pg_losses = -advantages * truncated_ratio * log_prob * valid_mask + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + + pg_clipfrac = verl_F.masked_mean((1.0 - valid_mask).float(), response_mask) + pg_clipfrac_lower = verl_F.masked_mean((ratio > clip_ratio_c).float() * valid_mask, response_mask) + + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + return pg_loss, pg_metrics + + +@register_policy_loss("dppo_kl") +def compute_policy_loss_dppo_kl( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for DPPO-Binary-KL. + + See https://arxiv.org/pdf/2602.04879 for more details. + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + config: `(verl.trainer.config.ActorConfig)`: + config for the actor. + rollout_log_probs: `(torch.Tensor)`: + log probabilities of actions under the rollout policy, shape (batch_size, response_length). + """ + + assert config is not None + assert not isinstance(config, AlgoConfig) + # Note: the clip_ratio is different from the standard PPO, it is the KL divergence threshold for DPPO. + clip_divergence = config.clip_ratio + clip_divergence_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_divergence + clip_divergence_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_divergence + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # Instead of dual-clip PPO, we use truncated importance sampling (TIS) to clip the policy loss. + # However, a large threshold is recommended to avoid performance degradation due to the truncation bias. + # See Section 5.4 in https://arxiv.org/pdf/2602.04879 for more details. + clip_ratio_c = config.get("clip_ratio_c", 20.0) + truncated_ratio = torch.clamp(ratio, max=clip_ratio_c) + truncated_ratio = truncated_ratio.detach() + + # Compute valid mask for DPPO-Binary-KL + prob = torch.exp(log_prob) + old_prob = torch.exp(old_log_prob) + binary_kl = old_prob * (old_log_prob - log_prob) + (1 - old_prob) * torch.log( + (1.0 - old_prob + 1e-8) / (1.0 - prob + 1e-8) + ) + valid_positive_mask = (binary_kl <= clip_divergence_high) | (prob <= old_prob) + valid_negative_mask = (binary_kl <= clip_divergence_low) | (prob >= old_prob) + valid_mask = torch.where(advantages > 0, valid_positive_mask, valid_negative_mask) + valid_mask = valid_mask.detach().float() + + pg_losses = -advantages * truncated_ratio * log_prob * valid_mask + + # Apply rollout correction weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info + ) + + # For compatibility, return zero for pg_clipfrac_lower (not used in standard DPPO) + pg_clipfrac = verl_F.masked_mean((1.0 - valid_mask).float(), response_mask) + pg_clipfrac_lower = verl_F.masked_mean((ratio > clip_ratio_c).float() * valid_mask, response_mask) + + pg_metrics = { + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + return pg_loss, pg_metrics + + @register_policy_loss("gspo") def compute_policy_loss_gspo( old_log_prob: torch.Tensor, diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 196e55969ce..ae43d2bad5c 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -303,6 +303,8 @@ def __init__( self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + self.checkpoint_manager = None + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): """ Creates the train and validation dataloaders. @@ -481,8 +483,7 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto: ) # For agent loop, we need reward model keys to compute score. - if self.async_rollout_mode: - gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) return gen_batch @@ -536,16 +537,9 @@ def _validate(self, merged: bool = False): print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") # pad to be divisible by dp_size - size_divisor = ( - self.actor_rollout_wg.world_size - if not self.async_rollout_mode - else self.config.actor_rollout_ref.rollout.agent.num_workers - ) + size_divisor = self.config.actor_rollout_ref.rollout.agent.num_workers test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) - if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) - else: - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) if self.use_rm and "rm_scores" not in test_output_gen_batch_padded.batch.keys(): # for colocate reward models, we need to sleep rollout model @@ -765,6 +759,8 @@ def init_workers(self): wg_kwargs["device_name"] = self.device_name for resource_pool, class_dict in self.resource_pool_to_cls.items(): + if not class_dict: + continue worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) wg_dict = self.ray_worker_group_cls( resource_pool=resource_pool, @@ -835,15 +831,16 @@ def init_workers(self): # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager # to stream reward computation with actor rollout reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None - self.async_rollout_manager = AgentLoopManager( + self.async_rollout_manager = AgentLoopManager.create( config=self.config, worker_group=self.actor_rollout_wg, rollout_resource_pool=actor_rollout_resource_pool, reward_loop_worker_handles=reward_loop_worker_handles, ) + checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine) self.checkpoint_manager = CheckpointEngineManager( - backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + config=checkpoint_engine_config, trainer=self.actor_rollout_wg, replicas=self.async_rollout_manager.rollout_replicas, ) @@ -1258,7 +1255,7 @@ def fit(self): return if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): - rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip = RolloutSkip(self.config, self.async_rollout_manager) rollout_skip.wrap_generate_sequences() # add tqdm @@ -1310,15 +1307,12 @@ def fit(self): with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, color="red"): - if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) - else: - if curr_step_profile: - self.async_rollout_manager.start_profile() - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) - self.checkpoint_manager.sleep_replicas() - if curr_step_profile: - self.async_rollout_manager.stop_profile() + if curr_step_profile: + self.async_rollout_manager.start_profile() + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -1327,15 +1321,12 @@ def fit(self): with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - if not self.async_rollout_mode: - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - else: - if curr_step_profile: - self.async_rollout_manager.start_profile() - gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) - self.checkpoint_manager.sleep_replicas() - if curr_step_profile: - self.async_rollout_manager.stop_profile() + if curr_step_profile: + self.async_rollout_manager.start_profile() + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() batch = batch.union(gen_baseline_output) # compute reward model score on batch rm_scores = None diff --git a/verl/trainer/sft_trainer.py b/verl/trainer/sft_trainer.py index 979d92b04a1..d23ebc5fa90 100644 --- a/verl/trainer/sft_trainer.py +++ b/verl/trainer/sft_trainer.py @@ -238,8 +238,14 @@ def _get_batch_seqlens(self, data): batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1) batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp) + dp_group = self.engine.get_data_parallel_group() + dp_size = self.engine.get_data_parallel_size() + + if dp_size == 1 or dp_group is None: + return batch_seqlens.tolist() + output_tensor = torch.empty( - (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),), + (batch_seqlens.shape[0] * dp_size,), dtype=batch_seqlens.dtype, device=self.device_name, ) # (global_bsz,) @@ -247,7 +253,7 @@ def _get_batch_seqlens(self, data): torch.distributed.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=batch_seqlens, - group=self.engine.get_data_parallel_group(), + group=dp_group, ) batch_seqlens = output_tensor.tolist() @@ -372,9 +378,9 @@ def fit(self): if self.engine.is_mp_src_rank_with_outputs(): val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name)) # average over data parallel group - torch.distributed.all_reduce( - val_loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group() - ) + dp_group = self.engine.get_data_parallel_group() + if dp_group is not None: + torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.AVG, group=dp_group) if is_logging: metric = {"val/loss": val_loss.detach().item()} diff --git a/verl/utils/dataset/__init__.py b/verl/utils/dataset/__init__.py index 6032d68c864..423bd9a9ccd 100644 --- a/verl/utils/dataset/__init__.py +++ b/verl/utils/dataset/__init__.py @@ -14,6 +14,5 @@ from .rl_dataset import RLHFDataset from .rm_dataset import RMDataset -from .sft_dataset import SFTDataset -__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"] +__all__ = ["RLHFDataset", "RMDataset"] diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 9da33228e21..081d1dcfafa 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -36,6 +36,7 @@ from verl.utils.dataset.dataset_utils import DatasetPadMode from verl.utils.dataset.vision_utils import process_image, process_video from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.py_functional import convert_nested_value_to_list_recursive logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -68,19 +69,6 @@ def print_assembled_message(tokenizer, message_list, input_ids, loss_mask, attn_ logger.debug(str) -def convert_nested_value_to_list_recursive(data_item): - if isinstance(data_item, dict): - return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()} - elif isinstance(data_item, list): - return [convert_nested_value_to_list_recursive(elem) for elem in data_item] - elif isinstance(data_item, np.ndarray): - # Convert to list, then recursively process the elements of the new list - return convert_nested_value_to_list_recursive(data_item.tolist()) - else: - # Base case: item is already a primitive type (int, str, float, bool, etc.) - return data_item - - class MultiTurnSFTDataset(Dataset): """ Dataset for multi-turn conversations where each assistant response should be trained diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py deleted file mode 100644 index 5fa8e07b252..00000000000 --- a/verl/utils/dataset/sft_dataset.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -SFT dataset -- We assume user pass a single parquet file. -- We load all the data into the memory. -Each parquet file contains -""" - -import numpy as np -import pandas as pd -import torch -from omegaconf.listconfig import ListConfig -from torch.utils.data import Dataset -from transformers import PreTrainedTokenizer - -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_to_local -from verl.utils.model import compute_position_id_with_mask - - -class SFTDataset(Dataset): - """ - This is an in-memory SFTDataset - - Arguments: - config (OmegaConf): the data config - """ - - def __init__(self, parquet_files: str | ListConfig, tokenizer, config, max_samples: int = -1): - prompt_key = config.get("prompt_key", "prompt") - prompt_dict_keys = config.get("prompt_dict_keys", None) - response_key = config.get("response_key", "response") - response_dict_keys = config.get("response_dict_keys", None) - max_length = config.get("max_length", 1024) - truncation = config.get("truncation", "error") - use_shm = config.get("use_shm", False) - self.shuffle = config.get("shuffle", False) - self.seed = config.get("seed") - self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {}) - - assert truncation in ["error", "left", "right"] - self.truncation = truncation - self.use_shm = use_shm - - if not isinstance(parquet_files, ListConfig): - parquet_files = [parquet_files] - - self.parquet_files = parquet_files - self.max_samples = max_samples - if isinstance(tokenizer, str): - tokenizer = hf_tokenizer(tokenizer) - self.tokenizer: PreTrainedTokenizer = tokenizer - - self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key] - self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key] - self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else [] - self.response_dict_keys = response_dict_keys if response_dict_keys else [] - - self.max_length = max_length - - self._download() - self._read_files_and_tokenize() - - def _download(self): - for i, parquet_file in enumerate(self.parquet_files): - self.parquet_files[i] = copy_to_local(parquet_file, verbose=True, use_shm=self.use_shm) - - def _read_files_and_tokenize(self): - def series_to_item(ls): - import numpy - import pandas - - while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: - ls = ls[0] - return ls - - dataframes = [] - for parquet_file in self.parquet_files: - # read parquet files and cache - dataframe = pd.read_parquet(parquet_file) - dataframes.append(dataframe) - self.dataframe = pd.concat(dataframes) - - total = len(self.dataframe) - print(f"dataset len: {len(self.dataframe)}") - - if self.max_samples > 0 and self.max_samples < total: - if self.shuffle: - rngs_args = (self.seed,) if self.seed is not None else () - rng = np.random.default_rng(*rngs_args) - indices = rng.choice(total, size=self.max_samples, replace=False) - else: - indices = np.arange(self.max_samples) - self.dataframe = self.dataframe.iloc[indices.tolist()] - print(f"selected {self.max_samples} random samples out of {total}") - - self.prompts = self.dataframe[self.prompt_key] - for key in self.prompt_dict_keys: - # type(x): pandas.core.series.Series - # type(x[0]): numpy.ndarray - # type(x[0][0]): dict - try: - self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1) # noqa: B023 - except Exception: - print(f"self.prompts={self.prompts}") - raise - if isinstance(self.prompts, pd.DataFrame): - self.prompts = self.prompts.squeeze() - self.prompts = self.prompts.tolist() - self.responses = self.dataframe[self.response_key] - for key in self.response_dict_keys: - try: - self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1) # noqa: B023 - except Exception: - print(f"self.responses={self.responses}") - raise - if isinstance(self.responses, pd.DataFrame): - self.responses = self.responses.squeeze() - self.responses = self.responses.tolist() - - def __len__(self): - return len(self.prompts) - - def __getitem__(self, item): - tokenizer = self.tokenizer - - prompt = self.prompts[item] - response = self.responses[item] - - # apply chat template - prompt_chat = [{"role": "user", "content": prompt}] - - # string - prompt_chat_str = tokenizer.apply_chat_template( - prompt_chat, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs - ) - response_chat_str = response + tokenizer.eos_token - - # tokenize - prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False) - prompt_ids = prompt_ids_output["input_ids"][0] - prompt_attention_mask = prompt_ids_output["attention_mask"][0] - - response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False) - response_ids = response_ids_output["input_ids"][0] - response_attention_mask = response_ids_output["attention_mask"][0] - - prompt_length = prompt_ids.shape[0] - response_length = response_ids.shape[0] - - input_ids = torch.cat((prompt_ids, response_ids), dim=-1) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) - - # padding to max length - sequence_length = input_ids.shape[0] - if sequence_length < self.max_length: - padded_input_ids = ( - torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) - * self.tokenizer.pad_token_id - ) - padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) - - input_ids = torch.cat((input_ids, padded_input_ids)) - attention_mask = torch.cat((attention_mask, padded_attention_mask)) - elif sequence_length > self.max_length: - if self.truncation == "left": - # actually, left truncation may not be reasonable - input_ids = input_ids[-self.max_length :] - attention_mask = attention_mask[-self.max_length :] - elif self.truncation == "right": - input_ids = input_ids[: self.max_length] - attention_mask = attention_mask[: self.max_length] - elif self.truncation == "error": - raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}") - else: - raise NotImplementedError(f"Unknown truncation method {self.truncation}") - - position_ids = compute_position_id_with_mask(attention_mask) - - loss_mask = attention_mask.clone() - if prompt_length > 1: - # mask out prompt for SFT. - loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0 - # mask out the last token in response - loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0 - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "loss_mask": loss_mask, - } diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 70f8cf600ee..8f7f8bef0d0 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -144,7 +144,7 @@ def lambda_policy_fn(module): @torch.no_grad() def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): - if fsdp_version(model) == 2: + if fsdp_version(model) == 2 or fsdp_version(model) == 0: offload_fsdp2_model_to_cpu(model, empty_cache) return @@ -178,7 +178,7 @@ def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): @torch.no_grad() def load_fsdp_model_to_gpu(model: FSDP): - if fsdp_version(model) == 2: + if fsdp_version(model) == 2 or fsdp_version(model) == 0: load_fsdp2_model_to_gpu(model) return @@ -438,7 +438,7 @@ def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True ): state_dict = model.state_dict() return state_dict - elif fsdp_version(model) == 2: + elif fsdp_version(model) == 2 or fsdp_version(model) == 0: from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict state_dict_config = StateDictOptions( diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index 40c6db492c9..3eca28cb6fd 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -263,6 +263,7 @@ def efficient_entropy_kernel_general_mainloop( _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size) for n in range(0, num_pid_n): start_offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) @@ -308,12 +309,14 @@ def efficient_entropy_kernel_general_mainloop( # scale logits by temperature logits *= rcp_temperature + logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf")) + # update global maximum _max_old = _max - m_pid_n = tl.max(logits, axis=1) + m_pid_n = tl.max(logits_for_lse, axis=1) _max = tl.maximum(_max_old, m_pid_n) - exp_logits = tl.exp(logits - _max[:, None]) + exp_logits = tl.exp(logits_for_lse - _max[:, None]) coeff = tl.exp(_max_old - _max) _accu = coeff * _accu + tl.sum(exp_logits, axis=1) diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 9572fb91962..708c6e24fa5 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -269,6 +269,8 @@ def peft_pre_wrap_hook(model): model = provider.provide_distributed_model( wrap_with_ddp=wrap_config.wrap_with_ddp, ddp_config=ddp_config, + fp16=provider.fp16, + bf16=provider.bf16, ) # Extract TransformerConfig from the created model @@ -440,6 +442,11 @@ def offload_megatron_model_to_cpu(models): # if the grad_data size is already zero, we assume that it is already offloaded buffer.grad_data_size = buffer.grad_data.storage().size() buffer.grad_data.storage().resize_(0) + # Offload frozen parameters not in DDP buffers (e.g. base model in LoRA/PEFT) + # DDP buffers only contain requires_grad=True params, so frozen params must be offloaded separately. + for param in model_chunk.module.parameters(): + if not param.requires_grad and param.device.type != "cpu": + param.data = param.data.to("cpu", non_blocking=True) else: # we need this for ref module for _, param in model_chunk.named_parameters(): @@ -451,7 +458,14 @@ def offload_megatron_model_to_cpu(models): @torch.no_grad() -def load_megatron_model_to_gpu(models, load_grad=True): +def load_megatron_model_to_gpu(models, load_grad=True, load_frozen_params=True): + """ + Load megatron model to GPU. + Args: + models: The model to load. + load_grad: Whether to load gradients. + load_frozen_params: Whether to load frozen parameters. + """ for model_chunk in models: if isinstance(model_chunk, DDP): model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] @@ -466,6 +480,13 @@ def load_megatron_model_to_gpu(models, load_grad=True): buffer.param_data.storage().resize_(buffer.param_data_size) # copy data from cpu to cuda buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) + + # Load frozen parameters that were offloaded (e.g. base model in LoRA/PEFT) + if load_frozen_params: + device_id = get_device_id() + for param in model_chunk.module.parameters(): + if not param.requires_grad and param.device.type == "cpu": + param.data = param.data.to(device_id, non_blocking=True) else: # we need this for ref module device_id = get_device_id() diff --git a/verl/utils/memory_buffer.py b/verl/utils/memory_buffer.py deleted file mode 100644 index 9386f0d88bc..00000000000 --- a/verl/utils/memory_buffer.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file contains utilities to manipulate torch memory buffers -""" - -from typing import Optional - -import torch -from torch import nn - -from verl.utils.device import get_device_name - - -class MemoryBuffer: - """ - A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying - memory. It must have a unique type to support this behavior. - """ - - def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None): - self.numel = numel - self.numel_padded = numel_padded - self.dtype = dtype - if source is not None: - self.data = source - else: - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False) - - def zero(self): - """Reset the buffer to zero.""" - self.data.zero_() - - def get(self, shape, start_index): - """Return a tensor with the input `shape` as a view into the - 1-D data starting at `start_index`.""" - end_index = start_index + shape.numel() - assert end_index <= self.numel, "requested tensor is out of the buffer range." - buffer_tensor = self.data[start_index:end_index] - buffer_tensor = buffer_tensor.view(shape) - return buffer_tensor - - -def calc_padded_numel(shape: torch.Size, dtype: torch.dtype): - """for cuda memory alignment, make sure alignment by 128-bits""" - align_numel = 128 // torch.finfo(dtype).bits - numel = shape.numel() - return (numel + align_numel - 1) // align_numel * align_numel - - -def get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]: - """ - Return a dictionary containing name to a shape and dtype. - """ - weight_buffer_meta = {} - for name, param in sorted(module.named_parameters()): - weight_buffer_meta[name] = {"shape": param.shape, "dtype": param.dtype} - return weight_buffer_meta - - -def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]: - """Build the memory buffer given weight_buffer_meta - - Args: - weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors - - Returns: a large memory buffer for each dtype that can hold all the tensors - - """ - memory_buffers = {} - total_numel_map = {} # map from dtype to the total numel - for name, meta_info in sorted(weight_buffer_meta.items()): - shape = meta_info["shape"] - dtype = meta_info["dtype"] - - assert isinstance(shape, torch.Size) - assert isinstance(dtype, torch.dtype) - - if dtype not in total_numel_map: - total_numel_map[dtype] = 0 - - total_numel_map[dtype] += calc_padded_numel(shape, dtype) - - for dtype, total_numel in total_numel_map.items(): - memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype) - - return memory_buffers - - -def build_memory_reference_from_module( - module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True -): - start_index = {} - for dtype in memory_buffers: - start_index[dtype] = 0 - for name, param in sorted(module.named_parameters()): - memory_buffer = memory_buffers[param.dtype] - buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype]) - # need to increment start_index - start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype) - if maintain_weight: - buffer.copy_(param.data) - param.data = buffer - - -def build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]): - """Build the memory references. The memory buffers are built using the build_memory_buffer API. - This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta. - - Args: - weight_buffer_meta: - memory_buffers: - - Returns: - - """ - start_idx = {} - weight_buffers = {} - for dtype in memory_buffers: - start_idx[dtype] = 0 - - for name, meta_info in sorted(weight_buffer_meta.items()): - shape = meta_info["shape"] - dtype = meta_info["dtype"] - - buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype]) - start_idx[dtype] += calc_padded_numel(shape, dtype) - weight_buffers[name] = buffer - - return weight_buffers - - -class MemoryBufferModuleWrapper: - """ - Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to - - It will change the checkpoint name - """ - - def __init__(self, module: nn.Module): - super().__init__() - self.module = module - self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module) - self.memory_buffers = build_memory_buffer(self.weight_buffer_meta) - build_memory_reference_from_module(self.module, self.memory_buffers) - - def get_memory_buffers(self): - return self.memory_buffers - - def get_weight_buffer_meta(self): - return self.weight_buffer_meta - - -class MegatronMemoryBufferForRollout: - """ - We assume that - - inference engine has tp + dp - - actor has tp + pp + dp - - the tp between inference engine and actor should be the same - - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer - - weight_buffers: contains a list of weight_buffers, each is a dict from name to param - - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that - the named_parameters may not be directly compatible with inference engine. User has to take care of - this part such as the layout mismatches. (e.g. qkv transpose) - - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory. - - When doing weight sync, the data is transfer via memory buffers - """ - - def __init__(self, transform_memory_param_fn): - self._memory_buffers = [] - self._weight_buffers = [] - self._named_parameters = {} - self.transform_memory_param_fn = transform_memory_param_fn - - def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]): - """ - Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct - a large buffer for each dtype in the weight_buffer. - - Args: - weight_buffer_meta: contains pp models, each pp models contains a dictionary of mapping from - - Returns: None - - """ - self.weight_buffer_meta_pp = weight_buffer_meta_pp - - for weight_buffer_meta in self.weight_buffer_meta_pp: - memory_buffer = build_memory_buffer(weight_buffer_meta) - self._memory_buffers.append(memory_buffer) - self._weight_buffers.append(None) - - def build_memory_reference(self): - for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp): - self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i]) - self._named_parameters = self.transform_memory_param_fn(self._weight_buffers) - - @property - def named_parameters(self): - return self._named_parameters - - @property - def weight_buffers(self): - return self._weight_buffers - - @property - def memory_buffers(self): - return self._memory_buffers diff --git a/verl/utils/net_utils.py b/verl/utils/net_utils.py index 1acef76a434..0c6ee35d118 100644 --- a/verl/utils/net_utils.py +++ b/verl/utils/net_utils.py @@ -70,15 +70,22 @@ def is_valid_ipv6_address(address: str) -> bool: return False -def get_free_port(address: str) -> tuple[int, socket.socket]: - family = socket.AF_INET - if is_valid_ipv6_address(address): - family = socket.AF_INET6 +def get_free_port(address: str, with_alive_sock: bool = False) -> tuple[int, socket.socket | None]: + """Find a free port on the given address. + + By default the socket is closed internally, suitable for immediate use. + Set with_alive_sock=True to keep the socket open as a port reservation, + preventing other calls from getting the same port. The caller is + responsible for closing the socket before the port is actually bound + by the target service (e.g. NCCL, uvicorn). + """ + family = socket.AF_INET6 if is_valid_ipv6_address(address) else socket.AF_INET sock = socket.socket(family=family, type=socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.bind((address, 0)) - port = sock.getsockname()[1] - return port, sock + if with_alive_sock: + return port, sock + sock.close() + return port, None diff --git a/verl/utils/profiler/torch_profile.py b/verl/utils/profiler/torch_profile.py index bd59ef54dca..e0ccd09f3d3 100644 --- a/verl/utils/profiler/torch_profile.py +++ b/verl/utils/profiler/torch_profile.py @@ -14,6 +14,7 @@ import functools import os +from datetime import datetime, timezone from typing import Callable, Optional import torch @@ -34,7 +35,11 @@ def get_torch_profiler( os.makedirs(save_path, exist_ok=True) - save_file_name = f"prof_rank-{rank}.json.gz" + current_time = datetime.now(tz=timezone.utc).astimezone() + timestamp = current_time.strftime("%Y%m%d%H%M%S%f")[:-3] + pid = os.getpid() + + save_file_name = f"prof_rank-{rank}_{pid}_{timestamp}.json.gz" if save_file_prefix: save_file_name = f"{save_file_prefix}_{save_file_name}" save_path = os.path.join(save_path, save_file_name) diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index 32dc8f52681..bb8a60bfea6 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -25,6 +25,8 @@ from types import SimpleNamespace from typing import Any, Callable, Iterator, Optional +import numpy as np + from verl.utils.metric import Metric @@ -339,3 +341,28 @@ def convert_to_regular_types(obj): elif isinstance(obj, dict): return {k: convert_to_regular_types(v) for k, v in obj.items()} return obj + + +def convert_nested_value_to_list_recursive(data_item): + if isinstance(data_item, dict): + return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()} + elif isinstance(data_item, list): + return [convert_nested_value_to_list_recursive(elem) for elem in data_item] + elif isinstance(data_item, np.ndarray): + # Convert to list, then recursively process the elements of the new list + return convert_nested_value_to_list_recursive(data_item.tolist()) + else: + # Base case: item is already a primitive type (int, str, float, bool, etc.) + return data_item + + +def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): + if len(list_of_dict) == 0: + return {} + keys = list_of_dict[0].keys() + output = {key: [] for key in keys} + for data in list_of_dict: + for key, item in data.items(): + assert key in output, f"Key '{key}' is not present in the keys of the first dictionary in the list." + output[key].append(item) + return output diff --git a/verl/utils/qat/__init__.py b/verl/utils/qat/__init__.py new file mode 100644 index 00000000000..6f2a85c814d --- /dev/null +++ b/verl/utils/qat/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +QAT (Quantization-Aware Training) module for verl. + +Supports NVFP4 (W4A4 and W4A16) quantization modes for FSDP training. + +Module Structure: +- core.py: QATConfig, apply_qat, enable_qat_fuse (training setup) +- linear.py: QATLinear layer with Triton kernels for fake quantization +- quantizer.py: QATQuantizer for true quantization + scale computation utilities +- vllm_patch.py: Patches for vLLM dynamic weight loading + +Usage: + from verl.utils.qat import apply_qat, QATConfig + + config = QATConfig(enable=True, mode="w4a16") + model = apply_qat(model, config) # Before FSDP wrapping +""" + +from verl.utils.qat.core import ( + QATConfig, + apply_qat, + enable_qat_fuse, + invalidate_all_scales, + load_quantization_config, +) +from verl.utils.qat.vllm_patch import ( + apply_qat_patches, + manual_process_weights_after_loading, + prepare_qat_for_load_weights, +) + +__all__ = [ + # Core + "QATConfig", + "apply_qat", + "load_quantization_config", + "enable_qat_fuse", + "invalidate_all_scales", + # vLLM Patch + "apply_qat_patches", + "manual_process_weights_after_loading", + "prepare_qat_for_load_weights", +] diff --git a/verl/utils/qat/core.py b/verl/utils/qat/core.py new file mode 100644 index 00000000000..9f3a1fbe6a8 --- /dev/null +++ b/verl/utils/qat/core.py @@ -0,0 +1,196 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QAT (Quantization-Aware Training) utilities for verl FSDP training.""" + +import json +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch.nn as nn + +from verl.base_config import BaseConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class QATConfig(BaseConfig): + """Unified configuration for QAT (Quantization-Aware Training).""" + + enable: bool = False + mode: str = "w4a16" + group_size: int = 16 + ignore_patterns: list[str] = field(default_factory=lambda: ["lm_head", "embed_tokens", "re:.*mlp.gate$"]) + activation_observer: str = "static_minmax" + quantization_config_path: Optional[str] = None + + +def load_quantization_config(qat_config: QATConfig) -> dict[str, Any]: + """Load quantization config JSON file from QATConfig.""" + if not qat_config.quantization_config_path: + raise ValueError("quantization_config_path is required when QAT is enabled") + + logger.info(f"Loading QAT quantization config from: {qat_config.quantization_config_path}") + + with open(qat_config.quantization_config_path) as f: + quant_config = json.load(f) + + if qat_config.ignore_patterns: + original_ignore = quant_config.get("ignore", []) + quant_config["ignore"] = qat_config.ignore_patterns + if original_ignore != qat_config.ignore_patterns: + logger.info(f"Overriding JSON 'ignore' field: {original_ignore} -> {qat_config.ignore_patterns}") + + logger.info("Successfully loaded QAT quantization config") + return quant_config + + +def _should_quantize(name: str, module: nn.Module, config: QATConfig) -> bool: + """Check if a module should be quantized.""" + if not isinstance(module, nn.Linear): + return False + + for pattern in config.ignore_patterns: + if pattern.startswith("re:"): + regex = pattern[3:] + if re.match(regex, name): + logger.debug(f"Ignoring {name} due to regex pattern: {regex}") + return False + else: + if pattern in name: + logger.debug(f"Ignoring {name} due to pattern: {pattern}") + return False + + if module.in_features % config.group_size != 0: + logger.warning( + f"Skipping {name}: in_features={module.in_features} not divisible by group_size={config.group_size}" + ) + return False + + return True + + +def apply_qat( + model: nn.Module, + config: QATConfig | dict[str, Any], +) -> nn.Module: + """Apply QAT to a model by replacing nn.Linear with QATLinear.""" + from verl.utils.qat.linear import QATLinear, QATMode + + if not isinstance(config, QATConfig): + config = QATConfig(**config) + + if not config.enable: + logger.info("QAT is disabled, returning original model") + return model + + mode = QATMode(config.mode.lower()) + logger.info(f"Applying QAT with mode={mode.value}, group_size={config.group_size}") + + modules_to_replace = [] + for name, module in model.named_modules(): + if _should_quantize(name, module, config): + modules_to_replace.append((name, module)) + + logger.info(f"Found {len(modules_to_replace)} Linear layers to convert to QAT") + + converted_count = 0 + for name, module in modules_to_replace: + if isinstance(module, QATLinear): + continue + + fake_quant_module = QATLinear.from_linear( + module, + mode=mode, + group_size=config.group_size, + activation_observer=config.activation_observer, + ) + + _set_module(model, name, fake_quant_module) + converted_count += 1 + + logger.info(f"Successfully applied QAT to {converted_count} layers") + + return model + + +def _set_module(model: nn.Module, name: str, new_module: nn.Module): + """Set a module in the model by its full name.""" + parts = name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], new_module) + + +FUSION_PATTERNS = { + "qkv": ["q_proj", "k_proj", "v_proj"], + "gate_up": ["gate_proj", "up_proj"], +} + + +def setup_fusion_siblings(model: nn.Module): + """Setup fusion siblings for QKV and GateUp layers.""" + import weakref + + from verl.utils.qat.linear import QATLinear + + qat_modules = {name: m for name, m in model.named_modules() if isinstance(m, QATLinear)} + + counts = {} + for group_name, suffixes in FUSION_PATTERNS.items(): + groups: dict[str, dict[str, nn.Module]] = {} + for name, module in qat_modules.items(): + for suffix in suffixes: + if name.endswith(suffix): + parent = name.rsplit(".", 1)[0] + groups.setdefault(parent, {})[suffix] = module + + count = 0 + for parent, projs in groups.items(): + if len(projs) >= 2: + modules = list(projs.values()) + for i, m in enumerate(modules): + siblings = modules[:i] + modules[i + 1 :] + m._fusion_siblings_ref = [weakref.ref(s) for s in siblings] + count += 1 + counts[group_name] = count + + logger.info(f"[QAT Fuse] Setup fusion siblings: {counts}") + return counts + + +def enable_qat_fuse(model: nn.Module): + """Enable QAT fuse mode: sets up fusion siblings for weight scale fusion.""" + setup_fusion_siblings(model) + model._qat_fuse_enabled = True + logger.info("[QAT Fuse] Enabled QAT fuse mode") + + +def invalidate_all_scales(model: nn.Module): + """Clear all cached weight scales after optimizer.step().""" + from verl.utils.qat.linear import QATLinear + + count = 0 + for module in model.modules(): + if isinstance(module, QATLinear): + module._weight_blockwise_scale = None + module._weight_global_scale = None + module._cached_weight_amax = None + count += 1 + + logger.debug(f"[QAT Fuse] Invalidated scales for {count} QATLinear layers") diff --git a/verl/utils/qat/linear.py b/verl/utils/qat/linear.py new file mode 100644 index 00000000000..4b6c6bc8f41 --- /dev/null +++ b/verl/utils/qat/linear.py @@ -0,0 +1,385 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QAT FakeQuantized Linear module for NVFP4 (W4A4/W4A16) with FSDP compatibility. + +Includes Triton kernels for high-performance FP4 quantization. +""" + +from enum import Enum +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["QATLinear", "QATMode"] + + +import triton +import triton.language as tl + +_TORCH_TO_TL_DTYPE = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} +FP4_E2M1_MAX: float = 6.0 +FP8_E4M3_MAX: float = 448.0 + + +@triton.jit +def _fp4_fake_quant_kernel( + x_ptr, + y_ptr, + M, + N, + global_scale_ptr, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + BLOCK_SIZE: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + NUM_FP4_BLOCKS: tl.constexpr, + OUT_DTYPE: tl.constexpr, + FP4_MAX: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + row_start = pid_m * TILE_M + col_start = pid_n * TILE_N + + x_block_ptr = tl.make_block_ptr( + base=x_ptr, + shape=(M, N), + strides=(stride_xm, stride_xn), + offsets=(row_start, col_start), + block_shape=(TILE_M, TILE_N), + order=(1, 0), + ) + y_block_ptr = tl.make_block_ptr( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + offsets=(row_start, col_start), + block_shape=(TILE_M, TILE_N), + order=(1, 0), + ) + + global_scale = tl.load(global_scale_ptr).to(tl.float32) + global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12) + + tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)) + x_abs = tl.abs(tile_reshaped) + + block_max = tl.max(x_abs, axis=2, keep_dims=True) + block_max_scaled = block_max / (FP4_MAX * global_scale_safe) + block_max_scaled = tl.minimum(block_max_scaled, FP8_MAX) + block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale + block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0) + + block_max_quant_broadcast = tl.broadcast_to(block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)) + abs_scaled = x_abs / block_max_quant_broadcast + + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, FP4_MAX)), + ), + ), + ), + ), + ) + + x_rescaled = q_val * block_max_quant_broadcast + x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled) + tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N)) + + tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1)) + + +def fp4_fake_quant_weight( + weight: torch.Tensor, + global_amax: torch.Tensor = None, + block_size: int = 16, + tile_rows: int = 16, + tile_cols: int = 64, +) -> torch.Tensor: + """Apply FP4 fake quantization using Triton kernel.""" + x_shape = weight.shape + x_dtype = weight.dtype + x = weight.reshape(-1, x_shape[-1]).contiguous() + M, N = x.shape + y = torch.empty_like(x) + + stride_xm, stride_xn = x.stride() + stride_ym, stride_yn = y.stride() + + tile_cols = max(tile_cols, block_size) + tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size + num_fp4_blocks = tile_cols_aligned // block_size + + if global_amax is None: + global_amax = weight.abs().max().to(torch.float32) + global_scale = global_amax.float() / (FP4_E2M1_MAX * FP8_E4M3_MAX) + + grid = (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) + + _fp4_fake_quant_kernel[grid]( + x, + y, + M, + N, + global_scale, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + BLOCK_SIZE=block_size, + TILE_M=tile_rows, + TILE_N=tile_cols_aligned, + NUM_FP4_BLOCKS=num_fp4_blocks, + OUT_DTYPE=_TORCH_TO_TL_DTYPE[x_dtype], + FP4_MAX=FP4_E2M1_MAX, + FP8_MAX=FP8_E4M3_MAX, + ) + return y.view(*x_shape) + + +class STEFP4QuantTriton(torch.autograd.Function): + """Straight-Through Estimator wrapper for Triton FP4 quantization kernel.""" + + @staticmethod + def forward(ctx, x: torch.Tensor, global_amax: torch.Tensor, block_size: int) -> torch.Tensor: + return fp4_fake_quant_weight(x, global_amax=global_amax, block_size=block_size) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple: + return grad_output, None, None + + +class QATMode(str, Enum): + """QAT quantization mode.""" + + W4A4 = "w4a4" # Weight 4-bit, Activation 4-bit (dynamic) + W4A16 = "w4a16" # Weight 4-bit, Activation 16-bit (weight only) + + +class QATLinear(nn.Linear): + """QAT FakeQuantized Linear layer with FSDP compatibility.""" + + _UNINITIALIZED_SCALE = -1.0 + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + mode: QATMode = QATMode.W4A4, + group_size: int = 16, + activation_observer: str = "static_minmax", # Observer strategy for activation global_scale + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(in_features, out_features, bias, device=device, dtype=dtype) + + self.mode = mode + self.group_size = group_size + self.activation_observer = activation_observer + + self._weight_blockwise_scale: Optional[torch.Tensor] = None + self._weight_global_scale: Optional[torch.Tensor] = None + self._cached_weight_amax: Optional[torch.Tensor] = None + self._fusion_siblings_ref = None + + if mode == QATMode.W4A4: + self.register_buffer( + "input_global_scale", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True + ) + + self.register_buffer( + "input_amax", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True + ) + + self._ema_decay: float = 0.01 + + self.fake_quant_enabled = True + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + mode: QATMode = QATMode.W4A4, + group_size: int = 16, + activation_observer: str = "static_minmax", + ) -> "QATLinear": + """Create QATLinear from an existing nn.Linear.""" + has_bias = linear.bias is not None + + new_linear = cls( + in_features=linear.in_features, + out_features=linear.out_features, + bias=has_bias, + mode=mode, + group_size=group_size, + activation_observer=activation_observer, + device=linear.weight.device, + dtype=linear.weight.dtype, + ) + + if linear.weight.device != torch.device("meta"): + new_linear.weight = nn.Parameter(linear.weight.clone()) + if has_bias: + new_linear.bias = nn.Parameter(linear.bias.clone()) + + return new_linear + + def _is_amax_initialized(self) -> bool: + """Check if input_amax has been initialized.""" + if not hasattr(self, "input_amax"): + return False + return self.input_amax.item() != self._UNINITIALIZED_SCALE + + def _update_input_global_scale(self, x: torch.Tensor): + """Update static input_global_scale based on observer strategy.""" + assert self.mode == QATMode.W4A4, "_update_input_global_scale should only be called in W4A4 mode" + + current_amax = torch.amax(torch.abs(x)).detach().to(torch.float32) + + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.all_reduce(current_amax, op=torch.distributed.ReduceOp.MAX) + + scale_factor = FP8_E4M3_MAX * FP4_E2M1_MAX + + if self.activation_observer == "memoryless_minmax": + new_scale = (scale_factor / (current_amax + 1e-12)).view(1) + self.input_global_scale.copy_(new_scale.to(self.input_global_scale.device)) + + elif self.activation_observer == "static_minmax": + if not self._is_amax_initialized(): + self.input_amax.copy_(current_amax.view(1).to(self.input_amax.device)) + else: + new_amax = torch.maximum(self.input_amax, current_amax.view(1).to(self.input_amax.device)) + self.input_amax.copy_(new_amax) + amax_f32 = self.input_amax.to(torch.float32) + new_scale = (scale_factor / (amax_f32 + 1e-12)).float().view(1) + self.input_global_scale.copy_(new_scale.to(self.input_global_scale.device)) + + elif self.activation_observer == "minmax": + if not self._is_amax_initialized(): + self.input_amax.copy_(current_amax.view(1).to(self.input_amax.device)) + else: + new_amax = (1 - self._ema_decay) * self.input_amax + self._ema_decay * current_amax.view(1).to( + self.input_amax.device + ) + self.input_amax.copy_(new_amax) + amax_f32 = self.input_amax.to(torch.float32) + new_scale = (scale_factor / (amax_f32 + 1e-12)).float().view(1) + self.input_global_scale.copy_(new_scale.to(self.input_global_scale.device)) + + else: + raise ValueError(f"Unknown activation_observer: {self.activation_observer}") + + def _fake_quantize_weight(self, weight: torch.Tensor) -> torch.Tensor: + """Apply fake quantization to weight tensor using Triton kernel.""" + with torch.no_grad(): + if self._cached_weight_amax is not None: + global_amax = self._cached_weight_amax + else: + siblings_ref = getattr(self, "_fusion_siblings_ref", None) + + if siblings_ref is not None: + siblings = [ref() for ref in siblings_ref if ref() is not None] + siblings = [s for s in siblings if s.weight.device != torch.device("meta")] + + for sibling in siblings: + sibling_amax = getattr(sibling, "_cached_weight_amax", None) + if sibling_amax is not None: + global_amax = sibling_amax + self._cached_weight_amax = global_amax + break + else: + all_modules = [self] + siblings + amaxes = [m.weight.abs().max().to(torch.float32) for m in all_modules] + global_amax = torch.max(torch.stack(amaxes)) + + self._cached_weight_amax = global_amax + for sibling in siblings: + sibling._cached_weight_amax = global_amax + else: + global_amax = weight.abs().max().to(torch.float32) + self._cached_weight_amax = global_amax + + if self._weight_global_scale is None: + self._weight_global_scale = global_amax.float() / (FP4_E2M1_MAX * FP8_E4M3_MAX) + + result = STEFP4QuantTriton.apply(weight, global_amax, self.group_size) + + return result + + def _fake_quantize_activation(self, x: torch.Tensor) -> torch.Tensor: + """Apply fake quantization to activation tensor (W4A4 mode only).""" + original_shape = x.shape + + if x.dim() == 3: + x_2d = x.view(-1, x.shape[-1]) + else: + x_2d = x + + if self.training: + self._update_input_global_scale(x_2d) + + if self.input_global_scale.item() == self._UNINITIALIZED_SCALE: + raise RuntimeError("W4A4 input_global_scale uninitialized. Load PTQ model first.") + + global_amax = (FP4_E2M1_MAX * FP8_E4M3_MAX) / self.input_global_scale.to(x.device) + result = STEFP4QuantTriton.apply(x_2d, global_amax, self.group_size) + return result.view(original_shape) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with fake quantization.""" + if not self.fake_quant_enabled: + return F.linear(x, self.weight, self.bias) + + weight_fq = self._fake_quantize_weight(self.weight) + + if self.mode == QATMode.W4A4: + x_fq = self._fake_quantize_activation(x) + else: + x_fq = x + + return F.linear(x_fq, weight_fq, self.bias) + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, mode={self.mode.value}, " + f"group_size={self.group_size}, fake_quant_enabled={self.fake_quant_enabled}" + ) diff --git a/verl/utils/qat/quantizer.py b/verl/utils/qat/quantizer.py new file mode 100644 index 00000000000..5ed96d46a3e --- /dev/null +++ b/verl/utils/qat/quantizer.py @@ -0,0 +1,308 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Fast NVFP4 Quantizer for verl FSDP training. + +Directly computes scales and quantizes weights using compressed_tensors APIs. +Includes scale computation utilities for weight quantization. +""" + +import logging +import os +import re +from typing import Generator, Iterable, Optional + +import torch +from compressed_tensors.compressors.quantized_compressors.fp4_quantized import NVFP4PackedCompressor +from compressed_tensors.quantization.quant_args import ( + FP4_E2M1_DATA, + FP8_E4M3_DATA, + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) +from compressed_tensors.quantization.utils.helpers import generate_gparam + +from verl.utils.device import get_device_name, get_torch_device + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +_LAYER_IDX_RE = re.compile(r"layers\.(\d+)\.") + + +def compute_blockwise_scale( + weight: torch.Tensor, + global_scale: torch.Tensor, + group_size: int = 16, +) -> torch.Tensor: + """Compute blockwise scale using pre-computed global_scale (for fusion). + Returns FP8 E4M3 blockwise scale tensor. + """ + out_features, in_features = weight.shape + num_groups = in_features // group_size + weight_reshaped = weight.view(out_features, num_groups, group_size) + block_max = torch.amax(torch.abs(weight_reshaped), dim=-1).to(torch.float32) + + local_scale = block_max / FP4_E2M1_DATA.max + blockwise_scale_f32 = torch.clamp( + global_scale * local_scale, + min=-FP8_E4M3_DATA.max, + max=FP8_E4M3_DATA.max, + ) + + blockwise_scale = blockwise_scale_f32.to(torch.float8_e4m3fn) + eps = torch.finfo(torch.float8_e4m3fn).eps + blockwise_scale = torch.where( + blockwise_scale == 0, + torch.tensor(eps, dtype=blockwise_scale.dtype, device=weight.device), + blockwise_scale, + ) + + return blockwise_scale + + +# Fusion patterns for transformer models +FUSE_PATTERNS = { + "qkv": ["q_proj", "k_proj", "v_proj"], + "gate_up": ["gate_proj", "up_proj"], +} + + +def fuse_global_scales( + layer_global_scales: dict[str, torch.Tensor], + strategy: str = "min", +) -> dict[str, torch.Tensor]: + """Fuse global scales for QKV/GateUp groups (take min across group).""" + if not layer_global_scales: + return {} + + # Group by parent module + parent_to_children: dict[str, dict[str, str]] = {} + for name in layer_global_scales: + parent, child = name.rsplit(".", 1) if "." in name else ("", name) + parent_to_children.setdefault(parent, {})[child] = name + + fused_scales = {} + processed = set() + + for parent, children in parent_to_children.items(): + for _, patterns in FUSE_PATTERNS.items(): + matched = [children[p] for p in patterns if p in children] + if len(matched) == len(patterns): + group_scales = [layer_global_scales[n] for n in matched] + if strategy == "min": + fused_scale = torch.min(torch.cat(group_scales)).reshape([1]) + else: + raise ValueError(f"Unknown fuse strategy: {strategy}") + for layer_name in matched: + fused_scales[layer_name] = fused_scale.clone() + processed.add(layer_name) + + for name, scale in layer_global_scales.items(): + if name not in processed: + fused_scales[name] = scale + + return fused_scales + + +class QATQuantizer: + """Quantizer for QAT-trained weights using compressed_tensors APIs.""" + + def __init__( + self, + mode: str = "w4a16", + group_size: int = 16, + ignore_patterns: Optional[list] = None, + device: Optional[torch.device] = None, + param_dtype: Optional[torch.dtype] = None, + ): + self.mode = mode.lower() + self._is_w4a4 = self.mode == "w4a4" # W4A4 needs input_global_scale + self.group_size = group_size + self.ignore_patterns = ignore_patterns or ["lm_head", "embed_tokens", "re:.*mlp.gate$"] + self.device = device or torch.device(get_device_name()) + self.param_dtype = param_dtype + + self._compressor = NVFP4PackedCompressor() + self._quant_args = QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + symmetric=True, + strategy=QuantizationStrategy.TENSOR_GROUP, + group_size=group_size, + scale_dtype=FP8_E4M3_DATA.dtype, + ) + + def _should_quantize(self, name: str, tensor: torch.Tensor) -> bool: + """Check if parameter should be quantized.""" + if not name.endswith(".weight"): + return False + if tensor.dim() != 2: + return False + if tensor.shape[1] % self.group_size != 0: + return False + + module_name = name.rsplit(".weight", 1)[0] + + for pattern in self.ignore_patterns: + if pattern.startswith("re:"): + # Regex pattern - use re.match like vLLM does + regex = pattern[3:] + if re.match(regex, module_name): + return False + else: + if pattern in module_name: + return False + return True + + @staticmethod + def _extract_layer_idx(name: str) -> Optional[int]: + """Extract decoder layer index from parameter name.""" + match = _LAYER_IDX_RE.search(name) + return int(match.group(1)) if match else None + + def _process_layer_group( + self, + layer_idx: Optional[int], + layer_params: dict[str, torch.Tensor], + input_global_scales: dict[str, torch.Tensor], + output_device: torch.device, + ) -> list[tuple[str, torch.Tensor]]: + """Quantize one decoder layer's buffered params. Returns list of (name, tensor).""" + layer_weights = {} + layer_passthrough = {} + + for name, tensor in layer_params.items(): + if "input_global_scale" in name or "input_amax" in name: + continue + + if self._should_quantize(name, tensor): + layer_name = name.rsplit(".weight", 1)[0] + layer_weights[layer_name] = (name, tensor) + else: + layer_passthrough[name] = tensor + + if layer_idx is None and layer_weights: + raise RuntimeError( + f"[QAT Quantizer] Unexpected quantizable weights outside decoder layers: " + f"{list(layer_weights.keys())}. These should be in ignore_patterns." + ) + + if not layer_weights: + return [(name, tensor.to(output_device)) for name, tensor in layer_passthrough.items()] + + # Move weights to GPU, compute global scales + weights_on_gpu = {} + layer_global_scales = {} + + for layer_name, (_, tensor) in layer_weights.items(): + weight_gpu = tensor.to(device=self.device, dtype=self.param_dtype) + weights_on_gpu[layer_name] = weight_gpu + amax = torch.amax(torch.abs(weight_gpu)).to(torch.float32) + layer_global_scales[layer_name] = generate_gparam( + -amax.unsqueeze(0), + amax.unsqueeze(0), + scale_data=FP8_E4M3_DATA, + quant_data=FP4_E2M1_DATA, + dtype=torch.float32, + ) + + fused_global_scales = fuse_global_scales(layer_global_scales, strategy="min") + + results = [] + + for layer_name, weight_gpu in weights_on_gpu.items(): + fused_global_scale = fused_global_scales[layer_name] + weight_scale = compute_blockwise_scale(weight_gpu, fused_global_scale, self.group_size) + weight_packed = self._compressor.compress_weight( + weight=weight_gpu, + scale=weight_scale.float(), + global_scale=fused_global_scale, + quantization_args=self._quant_args, + )["weight_packed"] + + results.append((f"{layer_name}.weight_packed", weight_packed.to(output_device))) + results.append((f"{layer_name}.weight_scale", weight_scale.to(output_device))) + results.append((f"{layer_name}.weight_global_scale", fused_global_scale.to(output_device))) + + if self._is_w4a4: + if layer_name in input_global_scales: + results.append( + ( + f"{layer_name}.input_global_scale", + input_global_scales[layer_name].float().to(output_device), + ) + ) + else: + raise ValueError( + f"W4A4 mode requires input_global_scale for layer '{layer_name}', " + f"but it's not found or uninitialized (-1.0)." + ) + + del weights_on_gpu, layer_global_scales, fused_global_scales + + for name, tensor in layer_passthrough.items(): + results.append((name, tensor.to(output_device))) + + return results + + def quantize_with_fusion( + self, + params: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], + target_device: Optional[torch.device] = None, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + """Streaming quantize: consume input layer by layer, yield (name, tensor) pairs.""" + if isinstance(params, dict): + params = params.items() + + output_device = target_device or torch.device("cpu") + + _sentinel = object() + current_layer_idx = _sentinel + layer_buffer: dict[str, torch.Tensor] = {} + input_global_scales: dict[str, torch.Tensor] = {} + for name, tensor in params: + tensor_cpu = tensor.to("cpu") if tensor.is_cuda else tensor + layer_idx = self._extract_layer_idx(name) + + # Collect input_global_scales for W4A4 as we go + if self._is_w4a4 and "input_global_scale" in name: + scale_layer_name = name.replace(".input_global_scale", "") + if tensor_cpu.numel() == 1 and tensor_cpu.item() == -1.0: + logger.warning(f"W4A4: {scale_layer_name} input_global_scale is uninitialized") + else: + input_global_scales[scale_layer_name] = tensor_cpu + + # Layer boundary: flush previous layer + if layer_idx != current_layer_idx and current_layer_idx is not _sentinel and layer_buffer: + yield from self._process_layer_group( + current_layer_idx, layer_buffer, input_global_scales, output_device + ) + layer_buffer = {} + + current_layer_idx = layer_idx + layer_buffer[name] = tensor_cpu + + # Flush last buffered layer + if layer_buffer: + yield from self._process_layer_group(current_layer_idx, layer_buffer, input_global_scales, output_device) + + get_torch_device().empty_cache() + + +__all__ = [ + "QATQuantizer", +] diff --git a/verl/utils/qat/vllm_patch.py b/verl/utils/qat/vllm_patch.py new file mode 100644 index 00000000000..77d2ec17eeb --- /dev/null +++ b/verl/utils/qat/vllm_patch.py @@ -0,0 +1,828 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +vLLM NVFP4 Patches for Dynamic Weight Updates. + +Enables dynamic weight reloading for NVFP4 quantized models in vLLM. + +Supported schemes: +- Dense: W4A16-FP4, W4A4-FP4 +- MoE: NVFP4-MoE +""" + +import logging +import os +from typing import Optional +from unittest.mock import patch + +import torch +from torch.nn import Parameter + +from verl.utils.device import get_device_name + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class ParamMetaDict(dict): + """ + Dict-like class for parameter management with metadata-based rebuild and tensor swap. + + Supports: + - Rebuild of deleted parameters from saved metadata + - Tensor Swap for parameters with shape changes (address stability for CUDA Graph) + """ + + def __init__(self, model: torch.nn.Module, device: Optional[torch.device] = None): + """ + Initialize ParamMetaDict from a model. + + Args: + model: vLLM model (may be wrapped in ModelRunner) + device: Device for created parameters + """ + super().__init__() + self.device = device + + # Get the actual model (handle vLLM's wrapper structure) + actual_model = model + if hasattr(model, "model"): + actual_model = model.model + self._model = actual_model + + # Build mappings by scanning all modules + self._layer_meta_cache: dict[str, dict] = {} # Cache of _hf_param_meta + self._tensor_swap_layers: dict[str, dict] = {} # Layers needing tensor swap + + self._build_mappings() + + # Initialize with current parameters + for name, param in actual_model.named_parameters(): + self[name] = param + + def _build_mappings(self): + """Build layer metadata cache for rebuild and tensor swap.""" + for layer_name, module in self._model.named_modules(): + # Check for _hf_param_meta which indicates this layer has HF format params + if hasattr(module, "_hf_param_meta"): + self._layer_meta_cache[layer_name] = { + "module": module, + "meta": module._hf_param_meta, + } + + # Check for tensor swap layers (weight_scale with shape change) + if "weight_scale" in module._hf_param_meta: + marlin_refs = getattr(module, "_marlin_tensor_refs", {}) + if "weight_scale" in marlin_refs: + self._tensor_swap_layers[layer_name] = { + "module": module, + "marlin_ref": marlin_refs["weight_scale"], + "hf_meta": module._hf_param_meta["weight_scale"], + } + + # MoE layers (w13_weight_scale, w2_weight_scale) + if "w13_weight_scale" in module._hf_param_meta: + marlin_refs = getattr(module, "_marlin_tensor_refs", {}) + if "w13_weight_scale" in marlin_refs: + self._tensor_swap_layers[f"{layer_name}.w13"] = { + "module": module, + "param_name": "w13_weight_scale", + "marlin_ref": marlin_refs["w13_weight_scale"], + "hf_meta": module._hf_param_meta["w13_weight_scale"], + } + if "w2_weight_scale" in marlin_refs: + self._tensor_swap_layers[f"{layer_name}.w2"] = { + "module": module, + "param_name": "w2_weight_scale", + "marlin_ref": marlin_refs["w2_weight_scale"], + "hf_meta": module._hf_param_meta["w2_weight_scale"], + } + + def _try_rebuild(self, key: str) -> Optional[Parameter]: + """ + Try to rebuild a parameter from metadata if it was deleted. + + Args: + key: Full parameter name + + Returns: + Rebuilt parameter or None if cannot rebuild + """ + # Extract layer name and param name + parts = key.rsplit(".", 1) + if len(parts) != 2: + return None + + layer_name, param_name = parts + + # Check if we have metadata for this layer + if layer_name not in self._layer_meta_cache: + return None + + cache_entry = self._layer_meta_cache[layer_name] + module = cache_entry["module"] + meta = cache_entry["meta"] + + # Check if this param needs rebuild + if param_name not in meta: + return None + + # Already exists on module? + if hasattr(module, param_name): + param = getattr(module, param_name) + if param is not None: + return param + + # Rebuild from metadata + new_param = _create_param_from_meta(module, param_name, meta[param_name], self.device) + module.register_parameter(param_name, new_param) + return new_param + + def prepare_for_reload(self) -> None: + """Replace Marlin-format tensors with HF-shape tensors for reload.""" + for layer_name, swap_info in self._tensor_swap_layers.items(): + module = swap_info["module"] + param_name = swap_info.get("param_name", "weight_scale") + hf_meta = swap_info["hf_meta"] + if hasattr(module, param_name): + new_param = _create_param_from_meta(module, param_name, hf_meta, self.device) + setattr(module, param_name, new_param) + + def __getitem__(self, key: str) -> Parameter: + """Get parameter with rebuild support.""" + # Try standard lookup first + if key in dict.keys(self): + return super().__getitem__(key) + + # Try rebuild from metadata + param = self._try_rebuild(key) + if param is not None: + self[key] = param + return param + + raise KeyError(f"Parameter not found: {key}") + + def __contains__(self, key: str) -> bool: + """Check if parameter exists (with rebuild check).""" + if super().__contains__(key): + return True + + # Check if can rebuild from metadata + parts = key.rsplit(".", 1) + if len(parts) == 2: + layer_name, param_name = parts + if layer_name in self._layer_meta_cache: + meta = self._layer_meta_cache[layer_name]["meta"] + if param_name in meta: + return True + + return False + + def get(self, key: str, default=None): + """Get parameter with default.""" + try: + return self[key] + except KeyError: + return default + + +def _create_param_from_meta( + module: torch.nn.Module, + param_name: str, + meta: dict, + device: Optional[torch.device] = None, +) -> Parameter: + """Create a Parameter from saved metadata. Used by rebuild and tensor swap.""" + shape = meta["shape"] + dtype = meta["dtype"] + dev = device or meta.get("device", get_device_name()) + param_class = meta.get("param_class", Parameter) + + weight_loaders = getattr(module, "_weight_loaders", {}) + weight_loader = weight_loaders.get(param_name) + + data = torch.empty(shape, dtype=dtype, device=dev) + + try: + if param_class is not Parameter and weight_loader is not None: + kwargs = {"data": data, "weight_loader": weight_loader} + if "input_dim" in meta: + kwargs["input_dim"] = meta["input_dim"] + if "output_dim" in meta: + kwargs["output_dim"] = meta["output_dim"] + new_param = param_class(**kwargs) + else: + new_param = Parameter(data, requires_grad=False) + if weight_loader is not None: + new_param.weight_loader = weight_loader + except Exception as e: + logger.warning(f"Failed to create param {param_name} with class {param_class}: {e}, using Parameter") + new_param = Parameter(data, requires_grad=False) + if weight_loader is not None: + new_param.weight_loader = weight_loader + + if "quant_method" in meta: + new_param.quant_method = meta["quant_method"] + + return new_param + + +def save_param_meta(layer: torch.nn.Module, param_name: str): + """Save parameter metadata for rebuild.""" + if not hasattr(layer, "_hf_param_meta"): + layer._hf_param_meta = {} + + param = getattr(layer, param_name, None) + if param is None: + return + + meta = { + "shape": tuple(param.shape), + "dtype": param.dtype, + "device": str(param.device), + "param_class": type(param), # Save the actual parameter class + } + + # Save vLLM-specific attributes needed for reconstruction + if hasattr(param, "_input_dim"): + meta["input_dim"] = param._input_dim + if hasattr(param, "_output_dim"): + meta["output_dim"] = param._output_dim + + # Save MoE-specific attributes (quant_method is required by weight_loader) + if hasattr(param, "quant_method"): + meta["quant_method"] = param.quant_method + + layer._hf_param_meta[param_name] = meta + + +def _check_first_call(layer: torch.nn.Module) -> bool: + """Check if this is the first process_weights call, and increment counter.""" + count = getattr(layer, "_process_weights_call_count", 0) + layer._process_weights_call_count = count + 1 + return count == 0 + + +# Dense W4A16 Patches +def patched_w4a16_process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Patched process_weights_after_loading for W4A16 Dense layer.""" + import vllm._custom_ops as ops + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + marlin_make_workspace_new, + marlin_permute_scales, + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + + is_first_call = _check_first_call(layer) + + group_size = 16 + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + device = layer.weight_packed.device + param_dtype = getattr(layer, "params_dtype", torch.float16) + + # Save metadata (first call only) + if is_first_call: + save_param_meta(layer, "weight_packed") + save_param_meta(layer, "weight_global_scale") + save_param_meta(layer, "weight_scale") + if not hasattr(layer, "_weight_loaders"): + layer._weight_loaders = {} + for pname in ["weight_packed", "weight_global_scale", "weight_scale"]: + param = getattr(layer, pname, None) + if param is not None and hasattr(param, "weight_loader"): + layer._weight_loaders[pname] = param.weight_loader + + # Get HF format data + weight_packed_hf = layer.weight_packed.data + weight_global_scale_hf = layer.weight_global_scale.data + weight_scale_hf = layer.weight_scale.data + + # Create workspace (first call only) + if is_first_call: + layer.workspace = marlin_make_workspace_new(device) + + # Convert to Marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = weight_packed_hf.view(torch.int32).T.contiguous() + marlin_weight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + is_a_8bit=False, + ) + + weight_scale = weight_scale_hf.T.contiguous().to(param_dtype) + weight_scale_permuted = marlin_permute_scales( + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=group_size, + is_a_8bit=False, + ) + marlin_weight_scale = nvfp4_marlin_process_scales(weight_scale_permuted) + + weight_scale_2_raw = (1.0 / weight_global_scale_hf.max()).to(param_dtype) + marlin_weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2_raw) + + # Update compute parameters + if is_first_call: + layer.weight = Parameter(marlin_weight, requires_grad=False) + layer.weight_scale = Parameter(marlin_weight_scale, requires_grad=False) + layer.weight_scale_2 = Parameter(marlin_weight_scale_2, requires_grad=False) + if not hasattr(layer, "_marlin_tensor_refs"): + layer._marlin_tensor_refs = {} + layer._marlin_tensor_refs["weight_scale"] = layer.weight_scale.data + else: + layer.weight.data.copy_(marlin_weight) + layer.weight_scale_2.data.copy_(marlin_weight_scale_2) + marlin_scale_ref = layer._marlin_tensor_refs.get("weight_scale") + if marlin_scale_ref is not None: + marlin_scale_ref.copy_(marlin_weight_scale) + layer.weight_scale = Parameter(marlin_scale_ref, requires_grad=False) + else: + logger.warning("W4A16: _marlin_tensor_refs['weight_scale'] not found") + layer.weight_scale = Parameter(marlin_weight_scale, requires_grad=False) + + # Delete HF parameters + if hasattr(layer, "weight_packed"): + delattr(layer, "weight_packed") + if hasattr(layer, "weight_global_scale"): + delattr(layer, "weight_global_scale") + + +def patched_w4a4_process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Patched process_weights_after_loading for W4A4 Dense (all backends).""" + from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale + + is_first_call = _check_first_call(layer) + + _W4A4_HF_PARAMS = ["weight_packed", "weight_scale", "weight_global_scale", "input_global_scale"] + + if is_first_call: + for pname in _W4A4_HF_PARAMS: + save_param_meta(layer, pname) + if not hasattr(layer, "_weight_loaders"): + layer._weight_loaders = {} + for pname in _W4A4_HF_PARAMS: + param = getattr(layer, pname, None) + if param is not None and hasattr(param, "weight_loader"): + layer._weight_loaders[pname] = param.weight_loader + + weight_packed_data = layer.weight_packed.data + weight_scale_data = layer.weight_scale.data + input_global_scale_data = layer.input_global_scale.data + weight_global_scale_data = layer.weight_global_scale.data + + global_input_scale = input_global_scale_data.max().to(torch.float32) + global_weight_scale = weight_global_scale_data.max().to(torch.float32) + + if self.backend == "flashinfer-trtllm": + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + epilogue_tile_m = 128 + processed_weight = shuffle_matrix_a(weight_packed_data.view(torch.uint8), epilogue_tile_m) + processed_weight_scale = ( + shuffle_matrix_sf_a(weight_scale_data.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale_data.shape) + .view(torch.float8_e4m3fn) + ) + elif self.backend == "fbgemm": + processed_weight_scale = swizzle_blockscale(weight_scale_data).view(-1).view(torch.uint8) + processed_weight = weight_packed_data + else: + # cutlass / flashinfer-cutlass + processed_weight_scale = swizzle_blockscale(weight_scale_data) + processed_weight = weight_packed_data + + alpha = 1.0 / (global_input_scale * global_weight_scale) + + if is_first_call: + layer.weight_packed = Parameter(processed_weight, requires_grad=False) + layer.weight_scale = Parameter(processed_weight_scale, requires_grad=False) + layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) + layer.weight_global_scale = Parameter(global_weight_scale, requires_grad=False) + layer.alpha = Parameter(alpha, requires_grad=False) + + if not hasattr(layer, "_marlin_tensor_refs"): + layer._marlin_tensor_refs = {} + layer._marlin_tensor_refs["weight_packed"] = layer.weight_packed.data + layer._marlin_tensor_refs["weight_scale"] = layer.weight_scale.data + layer._marlin_tensor_refs["input_global_scale"] = layer.input_global_scale.data + layer._marlin_tensor_refs["weight_global_scale"] = layer.weight_global_scale.data + layer._marlin_tensor_refs["alpha"] = layer.alpha.data + else: + refs = layer._marlin_tensor_refs + for ref_name, new_data in [ + ("weight_packed", processed_weight), + ("weight_scale", processed_weight_scale), + ("input_global_scale", global_input_scale), + ("weight_global_scale", global_weight_scale), + ("alpha", alpha), + ]: + ref = refs.get(ref_name) + if ref is not None: + ref.copy_(new_data) + setattr(layer, ref_name, Parameter(ref, requires_grad=False)) + else: + logger.warning(f"W4A4: _marlin_tensor_refs['{ref_name}'] not found, creating new Parameter") + setattr( + layer, + ref_name, + Parameter( + new_data.clone() if isinstance(new_data, torch.Tensor) else torch.tensor(new_data), + requires_grad=False, + ), + ) + + +def _marlin_repack_experts(packed, perm, size_k, size_n, num_experts): + """Repack weight for each expert into Marlin format and stack.""" + import vllm._custom_ops as ops + + result = [] + for i in range(num_experts): + qweight = packed[i].view(torch.int32).T.contiguous() + result.append( + ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + is_a_8bit=False, + ) + ) + return torch.stack(result) + + +def _marlin_process_scales_experts(scale_hf, param_dtype, size_k, size_n, group_size, num_experts): + """Process scales for each expert into Marlin format and stack.""" + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + marlin_permute_scales, + nvfp4_marlin_process_scales, + ) + + result = [] + scales = scale_hf.to(param_dtype) + for i in range(num_experts): + s = marlin_permute_scales( + s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=group_size, + is_a_8bit=False, + ) + result.append(nvfp4_marlin_process_scales(s)) + return torch.stack(result) + + +def _process_nvfp4_moe_marlin(self, layer: torch.nn.Module, is_first_call: bool) -> None: + """Process MoE layer with MARLIN backend (W4A16).""" + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import make_nvfp4_moe_kernel + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + marlin_make_workspace_new, + nvfp4_marlin_process_global_scale, + ) + + group_size = 16 + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + device = layer.w13_weight_packed.device + param_dtype = layer.params_dtype + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + if is_first_call: + layer.workspace = marlin_make_workspace_new(device, 4) + + perm = torch.empty(0, dtype=torch.int, device=device) + + if self.moe.is_act_and_mul and not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): + logger.warning("w1_weight_global_scale must match w3_weight_global_scale. Accuracy may be affected.") + + size_n_w13, size_k_w13 = n * w13_num_shards, k + size_n_w2, size_k_w2 = k, n + + w13_weight_marlin = _marlin_repack_experts(layer.w13_weight_packed.data, perm, size_k_w13, size_n_w13, e) + w2_weight_marlin = _marlin_repack_experts(layer.w2_weight_packed.data, perm, size_k_w2, size_n_w2, e) + w13_weight_scale_marlin = _marlin_process_scales_experts( + layer.w13_weight_scale.data, param_dtype, size_k_w13, size_n_w13, group_size, e + ) + w2_weight_scale_marlin = _marlin_process_scales_experts( + layer.w2_weight_scale.data, param_dtype, size_k_w2, size_n_w2, group_size, e + ) + + # Process global scales + w13_scale_2 = 1.0 / layer.w13_weight_global_scale[:, 0] + w2_scale_2 = 1.0 / layer.w2_weight_global_scale.data + w13_scale_2_processed = nvfp4_marlin_process_global_scale(w13_scale_2.to(param_dtype)) + w2_scale_2_processed = nvfp4_marlin_process_global_scale(w2_scale_2.to(param_dtype)) + + # Update parameters + if is_first_call: + layer.w13_weight = Parameter(w13_weight_marlin, requires_grad=False) + layer.w2_weight = Parameter(w2_weight_marlin, requires_grad=False) + layer.w13_weight_scale = Parameter(w13_weight_scale_marlin, requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale_marlin, requires_grad=False) + layer.w13_weight_scale_2 = Parameter(w13_scale_2_processed, requires_grad=False) + layer.w2_weight_scale_2 = Parameter(w2_scale_2_processed, requires_grad=False) + if not hasattr(layer, "_marlin_tensor_refs"): + layer._marlin_tensor_refs = {} + layer._marlin_tensor_refs["w13_weight_scale"] = layer.w13_weight_scale.data + layer._marlin_tensor_refs["w2_weight_scale"] = layer.w2_weight_scale.data + else: + layer.w13_weight.data.copy_(w13_weight_marlin) + layer.w2_weight.data.copy_(w2_weight_marlin) + layer.w13_weight_scale_2.data.copy_(w13_scale_2_processed) + layer.w2_weight_scale_2.data.copy_(w2_scale_2_processed) + w13_marlin_ref = layer._marlin_tensor_refs.get("w13_weight_scale") + w2_marlin_ref = layer._marlin_tensor_refs.get("w2_weight_scale") + if w13_marlin_ref is not None: + w13_marlin_ref.copy_(w13_weight_scale_marlin) + layer.w13_weight_scale = Parameter(w13_marlin_ref, requires_grad=False) + else: + logger.warning("MoE: _marlin_tensor_refs['w13_weight_scale'] not found") + layer.w13_weight_scale.data.copy_(w13_weight_scale_marlin) + if w2_marlin_ref is not None: + w2_marlin_ref.copy_(w2_weight_scale_marlin) + layer.w2_weight_scale = Parameter(w2_marlin_ref, requires_grad=False) + else: + logger.warning("MoE: _marlin_tensor_refs['w2_weight_scale'] not found") + layer.w2_weight_scale.data.copy_(w2_weight_scale_marlin) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + # Initialize kernel + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config is not None and ( + (not self.moe.moe_parallel_config.use_all2all_kernels) or self.moe.moe_parallel_config.use_naive_all2all_kernels + ): + self.kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + ) + + +def _process_nvfp4_moe_flashinfer_cutlass(self, layer: torch.nn.Module, is_first_call: bool) -> None: + """Process MoE layer with FlashInfer/CUTLASS backend (W4A4).""" + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + convert_to_nvfp4_moe_kernel_format, + make_nvfp4_moe_kernel, + ) + from vllm.model_executor.utils import replace_parameter + + w13_packed = layer.w13_weight_packed.data + w2_packed = layer.w2_weight_packed.data + w13_scale_hf = layer.w13_weight_scale.data + w2_scale_hf = layer.w2_weight_scale.data + + if self.moe.is_act_and_mul and not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): + logger.warning("w1_weight_global_scale must match w3_weight_global_scale. Accuracy may be affected.") + w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous() + + w13_temp = Parameter(w13_packed.clone(), requires_grad=False) + w2_temp = Parameter(w2_packed.clone(), requires_grad=False) + + if is_first_call: + layer.w13_weight = w13_temp + layer.w2_weight = w2_temp + + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=w13_temp, + w13_scale=w13_scale_hf, + w13_scale_2=(1.0 / w13_weight_global_scale), + a13_scale=(1.0 / layer.w13_input_global_scale), + w2=w2_temp, + w2_scale=w2_scale_hf, + w2_scale_2=(1.0 / layer.w2_weight_global_scale), + a2_scale=(1.0 / layer.w2_input_global_scale), + is_act_and_mul=self.moe.is_act_and_mul, + ) + + # Update parameters + if is_first_call: + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + layer.w13_weight_scale = Parameter(w13_scale, requires_grad=False) + layer.w2_weight_scale = Parameter(w2_scale, requires_grad=False) + if not hasattr(layer, "_marlin_tensor_refs"): + layer._marlin_tensor_refs = {} + layer._marlin_tensor_refs["w13_weight_scale"] = layer.w13_weight_scale.data + layer._marlin_tensor_refs["w2_weight_scale"] = layer.w2_weight_scale.data + else: + layer.w13_weight.data.copy_(w13.data) + layer.w2_weight.data.copy_(w2.data) + w13_scale_ref = layer._marlin_tensor_refs.get("w13_weight_scale") + w2_scale_ref = layer._marlin_tensor_refs.get("w2_weight_scale") + if w13_scale_ref is not None: + w13_scale_ref.copy_(w13_scale) + layer.w13_weight_scale = Parameter(w13_scale_ref, requires_grad=False) + else: + logger.warning("MoE W4A4: _marlin_tensor_refs['w13_weight_scale'] not found") + layer.w13_weight_scale.data.copy_(w13_scale) + if w2_scale_ref is not None: + w2_scale_ref.copy_(w2_scale) + layer.w2_weight_scale = Parameter(w2_scale_ref, requires_grad=False) + else: + logger.warning("MoE W4A4: _marlin_tensor_refs['w2_weight_scale'] not found") + layer.w2_weight_scale.data.copy_(w2_scale) + + layer.w13_weight_scale_2 = w13_scale_2 + layer.w2_weight_scale_2 = w2_scale_2 + layer.w13_input_scale = a13_scale + layer.w2_input_scale = a2_scale + + # Initialize kernel + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config is not None and ( + (not self.moe.moe_parallel_config.use_all2all_kernels) or self.moe.moe_parallel_config.use_naive_all2all_kernels + ): + self.kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + ) + + +# MoE NVFP4 Patches (entry points) +def patched_nvfp4_moe_process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Patched process_weights_after_loading for NVFP4 MoE layer.""" + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import NvFp4MoeBackend + + is_first_call = _check_first_call(layer) + + # Save metadata (first call only) + if is_first_call: + save_param_meta(layer, "w13_weight_packed") + save_param_meta(layer, "w2_weight_packed") + save_param_meta(layer, "w13_weight_scale") + save_param_meta(layer, "w2_weight_scale") + if not hasattr(layer, "_weight_loaders"): + layer._weight_loaders = {} + for pname in ["w13_weight_packed", "w2_weight_packed", "w13_weight_scale", "w2_weight_scale"]: + param = getattr(layer, pname, None) + if param is not None and hasattr(param, "weight_loader"): + layer._weight_loaders[pname] = param.weight_loader + + is_marlin = self.nvfp4_backend == NvFp4MoeBackend.MARLIN + if is_marlin: + _process_nvfp4_moe_marlin(self, layer, is_first_call) + else: + _process_nvfp4_moe_flashinfer_cutlass(self, layer, is_first_call) + + # Delete HF parameters + if hasattr(layer, "w13_weight_packed"): + delattr(layer, "w13_weight_packed") + if hasattr(layer, "w2_weight_packed"): + delattr(layer, "w2_weight_packed") + + +_PATCH_TARGETS = [ + # Dense W4A16 + ( + "vllm.model_executor.layers.quantization.compressed_tensors.schemes." + "compressed_tensors_w4a16_nvfp4.CompressedTensorsW4A16Fp4.process_weights_after_loading", + patched_w4a16_process_weights_after_loading, + ), + # Dense W4A4 + ( + "vllm.model_executor.layers.quantization.compressed_tensors.schemes." + "compressed_tensors_w4a4_nvfp4.CompressedTensorsW4A4Fp4.process_weights_after_loading", + patched_w4a4_process_weights_after_loading, + ), + # MoE NVFP4 + ( + "vllm.model_executor.layers.quantization.compressed_tensors." + "compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod.process_weights_after_loading", + patched_nvfp4_moe_process_weights_after_loading, + ), +] + +_applied_patches = [] + + +def apply_qat_patches(): + """Apply NVFP4 patches to support dynamic weight updates. Call before model loading.""" + global _applied_patches + + if _applied_patches: + logger.warning("QAT patches already applied, skipping") + return _applied_patches + + logger.info("Applying NVFP4 patches for dynamic weight loading...") + + for target, replacement in _PATCH_TARGETS: + p = patch(target, replacement) + _applied_patches.append(p) + p.start() + + logger.info(f"Applied {len(_applied_patches)} NVFP4 patches for dynamic weight loading") + return _applied_patches + + +def prepare_qat_for_load_weights(model, device=None): + """ + Prepare QAT model for weight loading. Call ONCE before multi-bucket weight loading. + + Args: + model: vLLM model + device: Device for created parameters + """ + inner_model = model + if hasattr(model, "model"): + inner_model = model.model + + param_meta = ParamMetaDict(inner_model, device=device) + + param_meta.prepare_for_reload() + logger.info(f"[prepare_qat] Tensor swap prepared for {len(param_meta._tensor_swap_layers)} layers") + + # Rebuild deleted (W4A16) or overwritten (W4A4) params back to HF format + rebuilt_count = 0 + for layer_name, cache_entry in param_meta._layer_meta_cache.items(): + module = cache_entry["module"] + for param_name, pm in cache_entry["meta"].items(): + existing = getattr(module, param_name, None) + if existing is not None: + hf_shape = tuple(pm["shape"]) + hf_dtype = pm["dtype"] + if ( + tuple(existing.shape) == hf_shape + and existing.dtype == hf_dtype + and hasattr(existing, "weight_loader") + ): + continue + new_param = _create_param_from_meta(module, param_name, pm, device) + module.register_parameter(param_name, new_param) + rebuilt_count += 1 + + logger.info(f"[prepare_qat] Rebuilt {rebuilt_count} parameters") + inner_model._param_meta_for_restore = param_meta + return param_meta + + +def manual_process_weights_after_loading(model): + """Trigger weight post-processing for all quantized layers after load_weights.""" + dense_count = 0 + moe_count = 0 + + actual_model = model + if hasattr(model, "model"): + actual_model = model.model + + for module in actual_model.modules(): + if hasattr(module, "scheme"): + module.scheme.process_weights_after_loading(module) + dense_count += 1 + + quant_method = getattr(module, "quant_method", None) + if quant_method is not None and not hasattr(module, "scheme"): + if hasattr(quant_method, "process_weights_after_loading"): + # Skip KV cache quantization methods + if "KVCache" in quant_method.__class__.__name__: + continue + quant_method.process_weights_after_loading(module) + moe_count += 1 + + logger.debug(f"Processed {dense_count} dense layers, {moe_count} MoE layers") + return dense_count + moe_count + + +__all__ = [ + "apply_qat_patches", + "prepare_qat_for_load_weights", + "manual_process_weights_after_loading", +] diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py index 5ba20649365..eff3d91085f 100644 --- a/verl/utils/ray_utils.py +++ b/verl/utils/ray_utils.py @@ -97,9 +97,13 @@ def get_event_loop(): def auto_await(func): """Auto await a coroutine function. - If the function is called in an async context (with a running event loop), - it will return the coroutine object. Otherwise, it will block the current thread - and run the coroutine until completion. + Handles three cases: + 1. When the decorated function is called with await: returns the coroutine + so the caller can await it. + 2. When called directly and there is no running event loop: runs the + coroutine with asyncio.run() and returns the result. + 3. When called directly and the event loop is already running: runs the + coroutine (e.g. in a thread pool to avoid deadlock) and returns the result. """ @functools.wraps(func) @@ -114,9 +118,22 @@ def wrapper(*args, **kwargs): except RuntimeError: loop = None - if loop and loop.is_running(): - return coro - else: + # Case 1: No running loop -> run with asyncio.run() + if loop is None: return asyncio.run(coro) + # Case 2: Running loop -> return coro if caller will await + caller_frame = inspect.currentframe() + if caller_frame is not None: + caller_frame = caller_frame.f_back + caller_is_async = caller_frame is not None and (caller_frame.f_code.co_flags & inspect.CO_COROUTINE) != 0 + if caller_is_async: + return coro + + # Case 3: Running loop -> run coro in thread pool + # (cannot block the loop thread without deadlock) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro) + return future.result() + return wrapper diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 46f82240448..51097f50a51 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -388,7 +388,7 @@ def rearrange_micro_batches( if min_num_micro_batch is not None: # used to support pp num_micro_batches = max(min_num_micro_batch, num_micro_batches) - if dist.is_initialized() and same_micro_num_in_dp: + if dist.is_initialized() and same_micro_num_in_dp and dp_group is not None: num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name()) dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) num_micro_batches = num_micro_batches.cpu().item() diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 2802e3642f1..8666bec2d16 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -710,6 +710,7 @@ def get_cosine_schedule_with_warmup( num_cycles: float = 0.5, last_epoch: int = -1, init_lr_ratio: float = None, + zero_indexed_step: bool = True, ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the @@ -731,6 +732,9 @@ def get_cosine_schedule_with_warmup( The index of the last epoch when resuming training. init_lr_ratio (:obj:`float`, `optional`, defaults to None): The initial lr ratio w.r.t the maximum. + zero_indexed_step (:obj:`bool`, `optional`, defaults to True): + Whether the LR schedule uses 0-indexed steps. If True (default), step counting starts at 0. + If False (used by torchtitan), step counting starts at 1. Return: :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ @@ -743,6 +747,8 @@ def get_cosine_schedule_with_warmup( assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0 def lr_lambda(current_step): + if not zero_indexed_step: + current_step += 1 if current_step < num_warmup_steps: return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps))) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py deleted file mode 100644 index 6014f4bc03e..00000000000 --- a/verl/utils/transferqueue_utils.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import functools -import inspect -import logging -import os -import threading -from functools import wraps -from typing import TYPE_CHECKING, Any, Callable - -if TYPE_CHECKING: - from verl.single_controller.base.decorator import Dispatch - -from tensordict import TensorDict - -try: - from transfer_queue import ( - AsyncTransferQueueClient, - BatchMeta, - TransferQueueClient, - ) - -except ImportError: - # TODO: Use a hacky workaround for ImportError since - # transfer_queue isn't a default verl dependency. - class BatchMeta: - pass - - -from verl.protocol import DataProto - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -_TRANSFER_QUEUE_CLIENT = None - -is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False) - - -def create_transferqueue_client( - client_id: str, - config, - sync: bool = False, -) -> "AsyncTransferQueueClient | TransferQueueClient": - global _TRANSFER_QUEUE_CLIENT - if _TRANSFER_QUEUE_CLIENT is None: - if sync: - _TRANSFER_QUEUE_CLIENT = TransferQueueClient(client_id, config.controller_info) - else: - _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, config.controller_info) - _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=config.storage_backend, config=config) - - return _TRANSFER_QUEUE_CLIENT - - -def get_transferqueue_client() -> "AsyncTransferQueueClient | TransferQueueClient": - return _TRANSFER_QUEUE_CLIENT - - -# TODO (TQ): verl will make all actor async, so this can be cleanup later. -def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: - # Use a temporary event loop in a new thread because event - # loop may already exist in server mode - tmp_event_loop = asyncio.new_event_loop() - thread = threading.Thread( - target=tmp_event_loop.run_forever, - name="batchmeta dataproto converter", - daemon=True, - ) - - def run_coroutine(coroutine): - if not thread.is_alive(): - thread.start() - future = asyncio.run_coroutine_threadsafe(coroutine, tmp_event_loop) - return future.result() - - async def stop_loop(): - tmp_event_loop.stop() - - try: - return run_coroutine(async_func(*args, **kwargs)) - finally: - if thread.is_alive(): - asyncio.run_coroutine_threadsafe(stop_loop(), tmp_event_loop) - thread.join() - - -def _find_batchmeta(*args, **kwargs): - for arg in args: - if isinstance(arg, BatchMeta): - return arg - for v in kwargs.values(): - if isinstance(v, BatchMeta): - return v - return None - - -async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: - if batchmeta.samples == [] or batchmeta.samples is None: - return DataProto( - batch=TensorDict({}, batch_size=(0,)), - non_tensor_batch={}, - meta_info=batchmeta.extra_info.copy(), - ) - - tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) - return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy()) - - -def _batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: - return _run_async_in_temp_loop(_async_batchmeta_to_dataproto, batchmeta) - - -async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta": - pid = os.getpid() - - for k, v in output.meta_info.items(): - batchmeta.set_extra_info(k, v) - - if len(output) > 0: - tensordict = output.to_tensordict() - # pop meta_info - for key in output.meta_info.keys(): - tensordict.pop(key) - - logger.info( - f"Task {func_name} (pid={pid}) putting output data to TransferQueue with " - f"batch_size={tensordict.batch_size},\n" - f"tensordict keys={list(tensordict.keys())}" - ) - - updated_batch_meta = await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) - return updated_batch_meta - else: - return batchmeta - - -def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta": - updated_batch_meta = _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta, func_name) - return updated_batch_meta - - -def _compute_need_collect(dispatch_mode: "dict | Dispatch", args: list) -> bool: - """Compute whether data collection is needed for the current worker. - - This function determines whether the current worker should collect data based on - the dispatch mode configuration and worker parameters. It's used to optimize - distributed data collection by ensuring only the appropriate rank collects data. - - Args: - dispatch_mode: Controls data collection logic for the current worker. Can be None, - a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch, - always returns True (current worker should collect). If dict, checks - collect_fn for lazy compute optimization. - args: List of arguments passed to the function. Should contain a Worker instance - as the first argument when using lazy compute mode. - - Returns: - bool: True if data collection is needed, False otherwise. - - Note: - Only checks worker attributes when dispatch_mode is a dict with 'collect_fn', - the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker. - Otherwise, returns True. For the lazy compute case, checks the worker's - data parallel rank for the mesh specified in collect_fn.args[0] to determine - if this worker should collect data. - """ - from verl.single_controller.base.decorator import Dispatch - from verl.single_controller.base.worker import Worker - - if dispatch_mode is None or isinstance(dispatch_mode, Dispatch): - return True - - assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode." - - collect_fn = dispatch_mode["collect_fn"] - - # Check if collect_fn is a functools.partial and handle gracefully - if isinstance(collect_fn, functools.partial): - collect_fn_name = collect_fn.func.__name__ - if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker): - return True - - collect_mesh_name = collect_fn.args[0] if collect_fn.args else None - if collect_mesh_name is None: - return True - - return args[0].query_collect_info(collect_mesh_name) - else: - # If collect_fn is not a partial, we can't extract mesh_name information - # Fall back to default behavior (collect data) - return True - - -def _postprocess_common(output, put_data, need_collect): - """Common post-processing logic for function outputs in TransferQueue bridge. - - This function handles the final return value based on whether data should be - put into storage (put_data) and whether collection is needed (need_collect). - It ensures proper return types based on the execution context. - - Args: - output: The original output from the decorated function. Can be any type. - put_data: bool, indicating whether the output should be put into TransferQueue. - If True, output will be put to TQ and return the corresponding BatchMeta; - if False, output will not be put into TQ. - need_collect: bool, indicating whether this process needs to collect data. - If False, the output will be replaced by an empty BatchMeta or DataProto - to avoid redundant communication. - - Returns: - - BatchMeta.empty(): When put_data=True but need_collect=False, indicating - no data should be stored but BatchMeta structure is expected. - - DataProto(): When put_data=False, need_collect=False, and output is DataProto, - returning an empty DataProto. - - output: In all other cases, returns the original output unchanged. - - Note: - This function is used in the tqbridge decorator to normalize return values - across different execution paths and avoid redundant data operations in - distributed scenarios. - """ - if put_data and not need_collect: - return BatchMeta.empty() - elif not put_data and not need_collect and isinstance(output, DataProto): - return DataProto() - else: - return output - - -def tqbridge(dispatch_mode: "dict | Dispatch" = None, put_data: bool = True): - """Creates a decorator for bridging BatchMeta and DataProto. - - This decorator automatically handles conversions between `BatchMeta` and - `DataProto` in function parameters, and decides whether to sync function - output back to `BatchMeta` based on configuration(`put_data`). It supports - both synchronous and asynchronous functions (async def), and can control - whether to enable enhanced logic via the global `HAS_TQ` variable (when disabled, - simply calls the original function as-is). - - Args: - dispatch_mode: Controls data collection behavior for the current worker. Passed to - _compute_need_collect to determine if current worker should collect data. - If None, _compute_need_collect will return True to fallback default logics. - put_data: Whether put the DataProto into Storage after func return. - If True, after function execution, the output result will be - updated to `BatchMeta` and `BatchMeta` will be returned; - If False, the function output result will be returned directly. - Defaults to True. - - Returns: - A decorator function used to decorate target functions (synchronous or asynchronous). - """ - - def decorator(func): - pid = os.getpid() - - @wraps(func) - def inner(*args, **kwargs): - batchmeta = _find_batchmeta(*args, **kwargs) - if batchmeta is None: - return func(*args, **kwargs) - else: - logger.info( - f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, " - f"global_idx={batchmeta.global_indexes}" - ) - args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] - kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()} - output = func(*args, **kwargs) - need_collect = _compute_need_collect(dispatch_mode, args) - if put_data and need_collect: - updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__) - return updated_batch_meta - return _postprocess_common(output, put_data, need_collect) - - @wraps(func) - async def async_inner(*args, **kwargs): - batchmeta = _find_batchmeta(*args, **kwargs) - if batchmeta is None: - return await func(*args, **kwargs) - else: - logger.info( - f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, " - f"global_idx={batchmeta.global_indexes}" - ) - args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] - kwargs = { - k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v - for k, v in kwargs.items() - } - output = await func(*args, **kwargs) - need_collect = _compute_need_collect(dispatch_mode, args) - if put_data and need_collect: - updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__) - return updated_batchmeta - return _postprocess_common(output, put_data, need_collect) - - @wraps(func) - def dummy_inner(*args, **kwargs): - output = func(*args, **kwargs) - return output - - @wraps(func) - async def dummy_async_inner(*args, **kwargs): - output = await func(*args, **kwargs) - return output - - wrapper_inner = inner if is_transferqueue_enabled else dummy_inner - wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner - - wrapper = wrapper_async_inner if inspect.iscoroutinefunction(func) else wrapper_inner - return wrapper - - return decorator diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index d524f0e2ba1..c13ffa09c2b 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -412,6 +412,13 @@ def _optimizer_step(self): self.actor_optimizer.zero_grad() else: self.actor_optimizer.step() + + # Clear cached weight scales for QAT (weights changed) + if getattr(self.actor_module, "_qat_fuse_enabled", False): + from verl.utils.qat import invalidate_all_scales + + invalidate_all_scales(self.actor_module) + return grad_norm @GPUMemoryLogger(role="dp actor", logger=logger) diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 7fdaa6e9811..f4a697866ab 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -139,6 +139,11 @@ def __init__( assert self.mtp_config.enable, "MTP requires mtp_config.enable to be True" self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if getattr(self.mtp_config, "enable", False) and self.use_fused_kernels: + self.use_fused_kernels = False + logger.warning_once( + "MTP is not compatible with fused kernels for now. Automatically disable use_fused_kernels." + ) if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False): # do not patch if overlap_moe_expert_parallel_comm is enabled logger.warning_once( @@ -434,6 +439,7 @@ def forward_backward_batch( temperature = data.meta_info["temperature"] if use_dynamic_bsz: assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + dp_group = mpu.get_data_parallel_group() vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage @@ -441,13 +447,16 @@ def forward_backward_batch( batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len, + dp_group=dp_group, ) assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " f"{microbatch_group_size_per_vp_stage} for megatron backend" ) else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, max_token_len=max_token_len, dp_group=dp_group + ) total_seqlen = max_token_len else: assert micro_batch_size is not None, ( diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index bcf05e8f2d2..66071e8ec20 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -18,10 +18,11 @@ from omegaconf import MISSING from verl.base_config import BaseConfig -from verl.trainer.config import CheckpointConfig +from verl.trainer.config import CheckpointConfig, RolloutCorrectionConfig from verl.utils.profiler.config import ProfilerConfig +from verl.utils.qat import QATConfig -from .engine import FSDPEngineConfig, McoreEngineConfig, VeOmniEngineConfig +from .engine import FSDPEngineConfig, McoreEngineConfig, TorchtitanEngineConfig, VeOmniEngineConfig from .model import HFModelConfig from .optimizer import OptimizerConfig @@ -32,6 +33,8 @@ "FSDPActorConfig", "McoreActorConfig", "VeOmniActorConfig", + "QATConfig", + "TorchTitanActorConfig", ] @@ -77,6 +80,7 @@ class PolicyLossConfig(BaseConfig): clip_cov_ub (float): Upper bound for clip-cov loss. kl_cov_ratio (float): Ratio of tokens to be applied KL penalty for kl-cov loss. ppo_kl_coef (float): KL divergence penalty coefficient. + rollout_correction (RolloutCorrectionConfig): Configuration for rollout correction. """ loss_mode: str = "vanilla" @@ -85,6 +89,7 @@ class PolicyLossConfig(BaseConfig): clip_cov_ub: float = 5.0 kl_cov_ratio: float = 0.0002 ppo_kl_coef: float = 0.1 + rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig) @dataclass @@ -294,6 +299,7 @@ class FSDPActorConfig(ActorConfig): use_rollout_log_probs: bool = False calculate_sum_pi_squared: bool = False sum_pi_squared_checkpointing: bool = False + qat: QATConfig = field(default_factory=QATConfig) def __post_init__(self): """Validate FSDP actor configuration parameters.""" @@ -336,3 +342,27 @@ def __post_init__(self): """Validate VeOmni actor configuration parameters.""" super().__post_init__() self.engine = self.veomni + + +@dataclass +class TorchTitanActorConfig(ActorConfig): + """Configuration for TorchTitan actor models. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + strategy (str): Training strategy set to 'torchtitan' for TorchTitan parallelism. + torchtitan (TorchtitanEngineConfig): Configuration for TorchTitan engine settings. + use_remove_padding (bool): Whether to remove padding tokens in inputs during training + use_rollout_log_probs (bool): Whether to use log probabilities from rollout engine + """ + + strategy: str = "torchtitan" + torchtitan: TorchtitanEngineConfig = field(default_factory=TorchtitanEngineConfig) + use_remove_padding: bool = False + use_rollout_log_probs: bool = False + + def __post_init__(self): + """Validate TorchTitan actor configuration parameters.""" + super().__post_init__() + self.engine = self.torchtitan diff --git a/verl/workers/config/critic.py b/verl/workers/config/critic.py index c347b54e754..caca5bac6ac 100644 --- a/verl/workers/config/critic.py +++ b/verl/workers/config/critic.py @@ -22,11 +22,11 @@ from verl.trainer.config import BaseModelConfig, CheckpointConfig from verl.utils.profiler import ProfilerConfig -from .engine import FSDPEngineConfig, McoreEngineConfig +from .engine import FSDPEngineConfig, McoreEngineConfig, TorchtitanEngineConfig from .model import HFModelConfig from .optimizer import OptimizerConfig -__all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "FSDPCriticModelCfg"] +__all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "TorchTitanCriticConfig", "FSDPCriticModelCfg"] @dataclass @@ -224,6 +224,26 @@ def validate(self, n_gpus: int, train_batch_size: int): ) +@dataclass +class TorchTitanCriticConfig(CriticConfig): + """Configuration for TorchTitan-based critic model training. + + The inheritance from CriticConfig provides all base critic configuration plus TorchTitan-specific settings. + + Args: + strategy (str): Training strategy set to 'torchtitan' for TorchTitan parallelism. + torchtitan (TorchtitanEngineConfig): Configuration for TorchTitan engine settings. + """ + + strategy: str = "torchtitan" + torchtitan: TorchtitanEngineConfig = field(default_factory=TorchtitanEngineConfig) + + def __post_init__(self): + """Validate TorchTitan critic configuration parameters.""" + super().__post_init__() + self.engine = self.torchtitan + + @dataclass class FSDPCriticModelCfg(BaseModelConfig): """FSDP-enabled critic model configuration. diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index feb559b374c..6caa57d3190 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -27,6 +27,7 @@ "FSDPEngineConfig", "McoreEngineConfig", "TrainingWorkerConfig", + "TorchtitanEngineConfig", "VeOmniEngineConfig", "EngineConfig", "EngineRouterReplayConfig", @@ -73,6 +74,8 @@ class EngineConfig(BaseConfig): "infer_micro_batch_size_per_gpu", "use_fused_kernels", "use_remove_padding", + "forward_only", + "param_offload", } # whether to offload param param_offload: bool = False @@ -235,14 +238,9 @@ class VeOmniEngineConfig(EngineConfig): optimizer_offload (bool): Whether to offload optimizer states to CPU, default False offload_policy (bool): Whether to offload policy model parameters, default False reshard_after_forward (bool): Whether to reshard parameters after forward pass, default True - data_parallel_size (int): FSDP group size, default 1 - data_parallel_replicate_size (int): Data parallel replicate size, default 1 - data_parallel_shard_size (int): Data parallel shard degree, default 1 - tensor_parallel_size (int): Tensor parallel size, default 1 - expert_parallel_size (int): Expert parallel size, default 1 - pipeline_parallel_size (int): Pipeline parallel size, default 1 - context_parallel_size (int): Ring-attn context parallel size, default 1 + fsdp_size (int): FSDP group size. -1 means use all available GPUs, default -1 ulysses_parallel_size (int): Ulysses sequence parallel size, default 1 + expert_parallel_size (int): Expert parallel size, default 1 init_device (str): Device to initialize model weights. 1. `cpu`: Init parameters on CPU in rank0 only. 2. `cuda`: Init parameters on GPU. @@ -291,14 +289,9 @@ class VeOmniEngineConfig(EngineConfig): use_torch_compile: bool = True entropy_checkpointing: bool = False strategy: str = "veomni" - data_parallel_size: int = 1 - data_parallel_replicate_size: int = 1 - data_parallel_shard_size: int = 1 - tensor_parallel_size: int = 1 - expert_parallel_size: int = 1 - pipeline_parallel_size: int = 1 - context_parallel_size: int = 1 + fsdp_size: int = -1 ulysses_parallel_size: int = 1 + expert_parallel_size: int = 1 seed: int = 42 full_determinism: bool = False mixed_precision: bool = False @@ -319,6 +312,66 @@ def __post_init__(self): assert self.strategy in ["veomni"], f"strategy {self.strategy} not supported" +@dataclass +class TorchtitanEngineConfig(EngineConfig): + """Configuration for Torchtitan. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy. + reshard_after_forward (Literal["default", "always", "never"]): The policy for applying + `reshard_after_forward` within an FSDP setup, default "default" + forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False + use_orig_params (bool): Whether to use original parameters when initialize FSDP, default False + mixed_precision (bool): Mixed precision configuration for FSDP, default False + offload_policy (bool): Whether to offload policy model parameters, default False + data_parallel_size (int): Data parallel group size, default 1 + data_parallel_replicate_size (int): Data parallel replicate size, default 1 + data_parallel_shard_size (int): Data parallel shard degree, default 1 + tensor_parallel_size (int): Tensor parallel size, default 1 + expert_parallel_size (int): Expert parallel size, default 1 + expert_tensor_parallel_size (int): Expert tensor parallel size, default 1 + pipeline_parallel_size (int): Pipeline parallel size, default 1 + context_parallel_size (int): Context parallel size, default 1 + attn_type (str): Attention type for torchtitan's model (e.g., "sdpa", "flex", "varlen"), + default "flex" + strategy (str): Strategy to use for distributed training, default "torchtitan" + seed (int): Random seed for reproducibility. + full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results + in distributed training. Important: this will negatively impact performance, so only use it for + debugging. + + """ + + wrap_policy: dict[str, Any] = field(default_factory=dict) + reshard_after_forward: Literal["default", "always", "never"] = "default" + forward_prefetch: bool = False + use_orig_params: bool = False + mixed_precision: bool = False + offload_policy: bool = False + use_torch_compile: bool = True + entropy_from_logits_with_chunking: bool = False + entropy_checkpointing: bool = False + data_parallel_size: int = 1 + data_parallel_replicate_size: int = 1 + data_parallel_shard_size: int = 1 + tensor_parallel_size: int = 1 + expert_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + pipeline_parallel_size: int = 1 + context_parallel_size: int = 1 + attn_type: str = "flex" + max_seq_len: Optional[int] = None + strategy: str = "torchtitan" + seed: int = 42 + full_determinism: bool = False + + def __post_init__(self): + super().__post_init__() + assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported" + + @dataclass class TrainingWorkerConfig(BaseConfig): model_type: str = None # model type (language_model/value_model) diff --git a/verl/workers/config/model.py b/verl/workers/config/model.py index 1aa2afc0843..9205a99f038 100644 --- a/verl/workers/config/model.py +++ b/verl/workers/config/model.py @@ -119,7 +119,7 @@ class HFModelConfig(BaseConfig): # fsdp lora related. We may setup a separate config later lora_rank: int = 0 lora_alpha: int = 16 - target_modules: Optional[str] = "all-linear" + target_modules: Optional[Any] = "all-linear" # allow both "all-linear" and ["q_proj","k_proj"] target_parameters: Optional[list[str]] = None # for lora adapter on nn.Parameter exclude_modules: Optional[str] = None @@ -204,5 +204,19 @@ def __post_init__(self): if getattr(self.hf_config, "model_type", None) == "kimi_vl": self.hf_config.text_config.topk_method = "greedy" + # Ensure target_modules is a str or list[str] (only if not None) + if self.target_modules is not None: + if not isinstance(self.target_modules, (str | list)): + raise TypeError( + "target_modules must be a string or a list of strings, " + f"but got {type(self.target_modules).__name__}" + ) + if isinstance(self.target_modules, list): + for x in self.target_modules: + if not isinstance(x, str): + raise TypeError( + f"All elements in target_modules list must be strings, but found {type(x).__name__}" + ) + def get_processor(self): return self.processor if self.processor is not None else self.tokenizer diff --git a/verl/workers/config/optimizer.py b/verl/workers/config/optimizer.py index bdb87667c25..b7f05bef518 100644 --- a/verl/workers/config/optimizer.py +++ b/verl/workers/config/optimizer.py @@ -19,7 +19,14 @@ from verl.base_config import BaseConfig -__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig", "build_optimizer", "VeOmniOptimizerConfig"] +__all__ = [ + "OptimizerConfig", + "FSDPOptimizerConfig", + "McoreOptimizerConfig", + "build_optimizer", + "VeOmniOptimizerConfig", + "TorchtitanOptimizerConfig", +] @dataclass @@ -88,6 +95,8 @@ class FSDPOptimizerConfig(OptimizerConfig): min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule. lr_scheduler_type (str): LR scheduler type: "constant" or "cosine". num_cycles (float): Number of cosine cycles in LR schedule. + zero_indexed_step (bool): Whether the LR schedule uses 0-indexed steps. If True (default), + step counting starts at 0. If False, step counting starts at 1. """ _mutable_fields = OptimizerConfig._mutable_fields.copy() @@ -101,6 +110,7 @@ class FSDPOptimizerConfig(OptimizerConfig): lr_scheduler_type: str = "constant" num_cycles: float = 0.5 override_optimizer_config: Optional[dict] = None + zero_indexed_step: bool = True def __post_init__(self): if self.warmup_style is not None: @@ -143,6 +153,23 @@ class McoreOptimizerConfig(OptimizerConfig): override_optimizer_config: Optional[dict] = None +@dataclass +class TorchtitanOptimizerConfig(OptimizerConfig): + """Torchtitan optimizer configuration extending base OptimizerConfig. + + Args: + name (str): Optimizer name; default is "AdamW". + eps (float): Epsilon value for AdamW optimizer, default 1e-8. + decay_type (str): Weight decay type: "linear", "sqrt", or "cosine". + min_lr_factor (float): Minimum learning rate factor. + """ + + name: str = "AdamW" + eps: float = 1e-8 + decay_type: str = "linear" + min_lr_factor: float = 0.0 + + def build_optimizer(parameters, config: FSDPOptimizerConfig): """Build an optimizer based on the configuration. diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 3b4e7c121a3..d1d5c8f1768 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -80,6 +80,8 @@ class AgentLoopConfig(BaseConfig): @dataclass class TraceConfig(BaseConfig): + project_name: Optional[str] = None + experiment_name: Optional[str] = None backend: Optional[str] = None token2text: bool = False max_samples_per_step_per_worker: Optional[int] = None @@ -125,7 +127,7 @@ class CheckpointEngineConfig(BaseConfig): """ # Backend for checkpoint engine: naive, nccl, nixl, hccl - backend: Optional[str] = MISSING + backend: Optional[str] = "naive" # Bucket size in MB to transfer multiple weights at one time update_weights_bucket_megabytes: int = 2048 # Additional keyword arguments for checkpoint engine @@ -138,6 +140,8 @@ class RolloutConfig(BaseConfig): name: Optional[str] = MISSING mode: str = "async" + nnodes: int = 0 + n_gpus_per_node: int = 8 temperature: float = 1.0 top_k: int = -1 @@ -238,6 +242,8 @@ class RolloutConfig(BaseConfig): mtp: MtpConfig = field(default_factory=MtpConfig) + qat: Optional[dict] = None + def __post_init__(self): """Validate the rollout config""" # Deprecation warning for mode field - only async mode is supported diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index ecc166cd495..af87cb74c09 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -181,6 +181,7 @@ def forward_backward_batch( indices = None if use_dynamic_bsz: assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + dp_group = mpu.get_data_parallel_group() vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage @@ -188,13 +189,16 @@ def forward_backward_batch( batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len, + dp_group=dp_group, ) assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " f"{microbatch_group_size_per_vp_stage} for megatron backend" ) else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, max_token_len=max_token_len, dp_group=dp_group + ) total_seqlen = max_token_len else: assert micro_batch_size is not None, ( diff --git a/verl/workers/engine/__init__.py b/verl/workers/engine/__init__.py index 7b8be1002c0..8f01080fdcb 100644 --- a/verl/workers/engine/__init__.py +++ b/verl/workers/engine/__init__.py @@ -21,6 +21,14 @@ "FSDPEngineWithLMHead", ] +try: + from .torchtitan import TorchTitanEngine, TorchTitanEngineWithLMHead + + __all__ += ["TorchTitanEngine", "TorchTitanEngineWithLMHead"] +except ImportError: + TorchTitanEngine = None + TorchTitanEngineWithLMHead = None + try: from .veomni import VeOmniEngine, VeOmniEngineWithLMHead diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py index cba8909ff64..dbe9eb2f4e1 100644 --- a/verl/workers/engine/fsdp/transformer_impl.py +++ b/verl/workers/engine/fsdp/transformer_impl.py @@ -62,9 +62,14 @@ from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs from verl.utils.py_functional import convert_to_regular_types from verl.utils.torch_functional import logprobs_from_logits -from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.utils.ulysses import ( + gather_outputs_and_unpad, + get_ulysses_sequence_parallel_group, + set_ulysses_sequence_parallel_group, + ulysses_pad, + ulysses_pad_and_slice_inputs, +) from verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig, HFModelConfig -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from ..base import BaseEngine, BaseEngineCtx, EngineRegistry from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches @@ -190,14 +195,15 @@ def _init_device_mesh(self): self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None + self.ulysses_parallel_group = None self.ulysses_sequence_parallel_size = self.engine_config.ulysses_sequence_parallel_size dp_size = self.get_data_parallel_size() if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( device_name, mesh_shape=(dp_size, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) + self.ulysses_parallel_group = self.ulysses_device_mesh["sp"].get_group() - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 def _build_module(self): @@ -406,6 +412,7 @@ def _build_lr_scheduler(self, optimizer): lr_scheduler_type = optim_config.lr_scheduler_type min_lr_ratio = optim_config.min_lr_ratio num_cycles = optim_config.num_cycles + zero_indexed_step = optim_config.zero_indexed_step if num_warmup_steps <= 0: num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio num_warmup_steps = int(num_warmup_steps_ratio * total_steps) @@ -422,6 +429,7 @@ def _build_lr_scheduler(self, optimizer): num_training_steps=total_steps, min_lr_ratio=min_lr_ratio, num_cycles=num_cycles, + zero_indexed_step=zero_indexed_step, ) else: raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") @@ -707,12 +715,13 @@ def __init__(self, engine: FSDPEngine, **kwargs): def __enter__(self): assert isinstance(self.engine, FSDPEngine) super().__enter__() - self.engine.ulysses_sharding_manager.__enter__() + self.prev_sp_group = get_ulysses_sequence_parallel_group() + set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group) self.engine.module.eval() def __exit__(self, exc_type, exc_value, traceback): assert isinstance(self.engine, FSDPEngine) - self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback) + set_ulysses_sequence_parallel_group(self.prev_sp_group) # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -732,12 +741,13 @@ def __init__(self, engine: FSDPEngine, **kwargs): def __enter__(self): assert isinstance(self.engine, FSDPEngine) super().__enter__() - self.engine.ulysses_sharding_manager.__enter__() + self.prev_sp_group = get_ulysses_sequence_parallel_group() + set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group) self.engine.module.train() def __exit__(self, exc_type, exc_value, traceback): assert isinstance(self.engine, FSDPEngine) - self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback) + set_ulysses_sequence_parallel_group(self.prev_sp_group) self.engine.optimizer_zero_grad() super().__exit__(exc_type, exc_value, traceback) diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 0e3f7ff6a29..5cb0824a96b 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -152,6 +152,10 @@ def _build_tf_config(self): # In case of invalid overrides, we need to make sure some critical params are set correctly provider.params_dtype = self.param_dtype + # Ensure dtype settings propagate to Megatron-Bridge/TE + provider.fp16 = self.param_dtype == torch.float16 + provider.bf16 = self.param_dtype == torch.bfloat16 + # Pass distributed info provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size @@ -315,6 +319,8 @@ def initialize(self): if self.engine_config.forward_only: self.optimizer = None self.lr_scheduler = None + self.to(device="cpu", model=self._is_offload_param, optimizer=False, grad=False) + log_gpu_memory_usage("After offload model during init (forward_only)", logger=logger) return self.optimizer = self._build_optimizer() @@ -598,12 +604,14 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw return {} def get_per_tensor_param(self, base_sync_done=False, **kwargs): - load_megatron_model_to_gpu(self.module, load_grad=False) peft_config = None non_merge_lora_sync = self.peft_cls is not None and not self.model_config.lora.get("merge", False) + adapter_only = base_sync_done and non_merge_lora_sync + # when lora adapter only, we only load adapter weights when base sync is done, otherwise load all weights + load_megatron_model_to_gpu(self.module, load_grad=False, load_frozen_params=not adapter_only) if self.vanilla_bridge: per_tensor_param = self.bridge.export_weights(self.module) - elif base_sync_done and non_merge_lora_sync: + elif adapter_only: # Only export adapter weights peft_config = build_peft_config_for_vllm(self.model_config.lora) per_tensor_param = self.bridge.export_adapter_weights(self.module) diff --git a/verl/workers/engine/torchtitan/__init__.py b/verl/workers/engine/torchtitan/__init__.py new file mode 100644 index 00000000000..345757277af --- /dev/null +++ b/verl/workers/engine/torchtitan/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .transformer_impl import TorchTitanEngine, TorchTitanEngineWithLMHead + +__all__ = ["TorchTitanEngine", "TorchTitanEngineWithLMHead"] diff --git a/verl/workers/engine/torchtitan/transformer_impl.py b/verl/workers/engine/torchtitan/transformer_impl.py new file mode 100644 index 00000000000..002ea20e4ff --- /dev/null +++ b/verl/workers/engine/torchtitan/transformer_impl.py @@ -0,0 +1,736 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The concrete Engine implementation using PyTorch TorchTitan parallelism (FSDP2 + TP + PP) +""" + +import gc +import logging +import os +import re +from contextlib import nullcontext +from typing import Any, Callable, Optional + +import torch +import torch.distributed +from tensordict import TensorDict +from torch.distributed.checkpoint.state_dict import get_model_state_dict +from torch.distributed.tensor import DTensor +from torchtitan.config.job_config import ( + Checkpoint, + Compile, + JobConfig, + LRScheduler, + Model, + Optimizer, + Parallelism, + Training, +) +from torchtitan.distributed import utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.train import Trainer + +import verl.utils.torch_functional as verl_F +from verl.trainer.config import CheckpointConfig +from verl.utils import tensordict_utils as tu +from verl.utils.dataset.dataset_utils import DatasetPadMode +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_id, get_device_name +from verl.utils.fsdp_utils import ( + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) +from verl.utils.model import extract_multi_modal_inputs +from verl.utils.torch_functional import logprobs_from_logits +from verl.workers.config import HFModelConfig, TorchtitanEngineConfig, TorchtitanOptimizerConfig +from verl.workers.engine.torchtitan.utils import ( + derive_torchtitan_name_and_flavor, + enable_fsdp_gradient_division, + get_attention_masks, +) + +from ..base import BaseEngine, BaseEngineCtx, EngineRegistry +from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + + +class TorchTitanEngine(BaseEngine): + """ + Concrete Engine implementation using PyTorch TorchTitan parallelism. + + Supports model sharding with FSDP2, tensor parallelism, activation/optimizer offloading, + LoRA, and sequence parallelism following the TorchTitan design. + """ + + def __init__( + self, + model_config: HFModelConfig, + engine_config: TorchtitanEngineConfig, + optimizer_config: TorchtitanOptimizerConfig, + checkpoint_config: CheckpointConfig, + ): + """ + Initialize the TorchTitanEngine. + + Sets up distributed device meshes for tensor and data parallelism, LoRA, and offload policies. + + Args: + model_config: Configuration for HuggingFace model. + engine_config: Configuration for FSDP/TorchTitan engine (uses FSDP2). + optimizer_config: Configuration for optimizer. + checkpoint_config: Configuration for checkpointing. + """ + super().__init__() + + self.model_config = model_config + self.engine_config = engine_config + self.optimizer_config = optimizer_config + self.checkpoint_config = checkpoint_config + + # Disable torchtitan's dataloader since verl has its own data loading + # Ideally torchtitan trainer init should not initialize dataloader + import torchtitan.protocols.train_spec as train_spec_module + + original_get_train_spec = train_spec_module.get_train_spec + + def _get_train_spec_without_dataloader(model_name): + train_spec = original_get_train_spec(model_name) + train_spec.build_dataloader_fn = None + return train_spec + + train_spec_module.get_train_spec = _get_train_spec_without_dataloader + + # Derive torchtitan model name and flavor from HF config + torchtitan_name, torchtitan_flavor = derive_torchtitan_name_and_flavor(self.model_config.hf_config) + + # Get train_spec and directly override model_args before Trainer init + train_spec = train_spec_module.get_train_spec(torchtitan_name) + model_args = train_spec.model_args.get(torchtitan_flavor) + if model_args is not None: + if hasattr(model_args, "attn_type"): + model_args.attn_type = self.engine_config.attn_type + + model = Model( + name=torchtitan_name, + flavor=torchtitan_flavor, + hf_assets_path=self.model_config.path, + ) + optimizer = Optimizer( + name=self.optimizer_config.name, + lr=self.optimizer_config.lr, + eps=self.optimizer_config.eps, + beta1=self.optimizer_config.betas[0], + beta2=self.optimizer_config.betas[1], + weight_decay=self.optimizer_config.weight_decay, + ) + + total_steps = self.optimizer_config.total_training_steps + lr_warmup_steps = self.optimizer_config.lr_warmup_steps + if lr_warmup_steps is None or lr_warmup_steps <= 0: + lr_warmup_steps = int(self.optimizer_config.lr_warmup_steps_ratio * total_steps) + + lr_scheduler = LRScheduler( + warmup_steps=lr_warmup_steps, + decay_type=self.optimizer_config.decay_type, + min_lr_factor=self.optimizer_config.min_lr_factor, + ) + parallelism = Parallelism( + data_parallel_replicate_degree=self.engine_config.data_parallel_replicate_size, + data_parallel_shard_degree=self.engine_config.data_parallel_shard_size, + fsdp_reshard_after_forward=self.engine_config.reshard_after_forward, + tensor_parallel_degree=self.engine_config.tensor_parallel_size, + pipeline_parallel_degree=self.engine_config.pipeline_parallel_size, + context_parallel_degree=self.engine_config.context_parallel_size, + expert_parallel_degree=self.engine_config.expert_parallel_size, + expert_tensor_parallel_degree=self.engine_config.expert_tensor_parallel_size, + ) + checkpoint = Checkpoint( + enable=True, + initial_load_in_hf=True, + initial_load_model_only=True, + initial_load_path=model_config.path, + ) + compile = Compile(enable=self.engine_config.use_torch_compile) + training_kwargs = {} + if self.engine_config.max_seq_len is not None: + training_kwargs["seq_len"] = self.engine_config.max_seq_len + if self.engine_config.offload_policy or self.engine_config.forward_only: + training = Training(enable_cpu_offload=True, **training_kwargs) + else: + training = Training(**training_kwargs) + + # Construct Torchtitan's JobConfig + self.config = JobConfig( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + parallelism=parallelism, + checkpoint=checkpoint, + compile=compile, + training=training, + ) + self.trainer = Trainer(self.config) + + self._init_device_mesh() + + # Re-enable FSDP's gradient division for verl's loss scaling. + # TorchTitan disables gradient division by default (for global token normalization), + # but verl's loss function multiplies by dp_size to compensate for gradient averaging. + if self.engine_config.data_parallel_shard_size > 1: + dp_size = self.get_data_parallel_size() + for model_part in self.trainer.model_parts: + enable_fsdp_gradient_division(model_part, dp_size) + + if self.engine_config.full_determinism: + enable_full_determinism(seed=self.engine_config.seed) + + # set FSDP offload params + self._is_offload_param = self.engine_config.param_offload + self._is_offload_optimizer = self.engine_config.optimizer_offload + + if self.engine_config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + + self.compute_entropy_from_logits = ( + torch.compile(entropy_from_logits, dynamic=True) + if self.engine_config.use_torch_compile + else entropy_from_logits + ) + + @property + def is_param_offload_enabled(self) -> bool: + return self._is_offload_param + + @property + def is_optimizer_offload_enabled(self) -> bool: + return self._is_offload_optimizer + + def is_mp_src_rank_with_outputs(self): + """ + Whether the current rank is the first rank in model parallel group that contains model outputs + """ + is_collect = True + # TP: outputs are on TP rank 0 + if self.parallel_dims.tp > 1: + tp_mesh = self.parallel_dims.get_optional_mesh("tp") + is_collect = is_collect and (tp_mesh.get_local_rank() == 0) + # PP: outputs are on the last PP rank + if self.parallel_dims.pp > 1: + pp_mesh = self.parallel_dims.get_optional_mesh("pp") + is_collect = is_collect and (pp_mesh.get_local_rank() == self.parallel_dims.pp - 1) + # CP: outputs are on CP rank 0 + if self.parallel_dims.cp > 1: + cp_mesh = self.parallel_dims.get_optional_mesh("cp") + is_collect = is_collect and (cp_mesh.get_local_rank() == 0) + return is_collect + + def initialize(self): + """ + Build the model, optimizer, and learning rate scheduler with TorchTitan parallelism. + + Applies device, dtype, and precision configurations, including mixed precision. + Sets up checkpoint manager. + """ + self.module = self.trainer.model_parts + self.checkpointer = self.trainer.checkpointer + # load initial HF weights + self.checkpointer.load() + + if not self.engine_config.forward_only: + self.optimizer = self.trainer.optimizers + self.lr_scheduler = self.trainer.lr_schedulers + else: + self.optimizer = None + self.lr_scheduler = None + + self.to( + device="cpu", + model=self._is_offload_param, + optimizer=self._is_offload_optimizer, + grad=self._is_offload_param, + ) + + log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger) + + def _init_device_mesh(self): + """Initialize the device mesh for TorchTitan style parallelism.""" + world_size = torch.distributed.get_world_size() + self.parallel_dims = ParallelDims( + dp_shard=self.engine_config.data_parallel_shard_size, + dp_replicate=self.engine_config.data_parallel_replicate_size, + cp=self.engine_config.context_parallel_size, + tp=self.engine_config.tensor_parallel_size, + pp=self.engine_config.pipeline_parallel_size, + ep=self.engine_config.expert_parallel_size, + etp=self.engine_config.expert_tensor_parallel_size, + world_size=world_size, + ) + self.device_mesh = self.parallel_dims.build_mesh() + + def train_mode(self, **kwargs): + """Return a context manager for training mode.""" + return EngineTrainModeCtx(self, **kwargs) + + def eval_mode(self, **kwargs): + """Return a context manager for evaluation mode.""" + return EngineEvalModeCtx(self, **kwargs) + + def get_data_parallel_rank(self): + mesh = self._get_data_parallel_mesh() + return 0 if mesh is None else mesh.get_local_rank() + + def get_data_parallel_size(self): + return self.engine_config.data_parallel_shard_size * self.engine_config.data_parallel_replicate_size + + def get_data_parallel_group(self): + mesh = self._get_data_parallel_mesh() + if mesh is not None: + return mesh.get_group() + # If world_size == dp_size (e.g. single GPU, or all ranks are DP), + # return WORLD so that collective ops in _postprocess_output + # (allgather_dict_into_dict, all_reduce) still run and produce the + # correct metric aggregation format. + if torch.distributed.get_world_size() == self.get_data_parallel_size(): + return torch.distributed.group.WORLD + return None + + def _get_data_parallel_mesh(self): + """Get the data parallel mesh, handling hybrid/fully/replicate shard modes.""" + mesh = self.parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"]) + if mesh is None: + mesh = self.parallel_dims.get_optional_mesh("fsdp") + if mesh is None: + mesh = self.parallel_dims.get_optional_mesh("dp_replicate") + return mesh + + def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False): + """Perform forward and optionally backward pass on a batch.""" + tu.assign_non_tensor(data, sp_size=self.engine_config.tensor_parallel_size) + + # Compute num_tokens in global batch for loss normalization + batch_num_tokens = data["loss_mask"].sum().to(get_device_id()) + dp_group = self.get_data_parallel_group() + if dp_group is not None: + torch.distributed.all_reduce(batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=dp_group) + tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item()) + tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size()) + + micro_batches, indices = prepare_micro_batches( + data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True + ) + + output_lst = [] + + ctx = torch.no_grad() if forward_only else nullcontext() + + for micro_batch in micro_batches: + with ctx: + loss, output = self.forward_step(micro_batch, loss_function=loss_function, forward_only=forward_only) + if not forward_only: + loss.backward() + output_lst.append(output) + + return postprocess_batch_func(output_lst=output_lst, indices=indices, data=data) + + def model_forward_step( + self, + *, + inputs: torch.Tensor, + extra_inputs: dict[str, torch.Tensor] | None = None, + extra_kwargs: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + """ + Perform a forward pass through the trainer model without backward. + """ + model_parts = self.module + parallel_dims = self.parallel_dims + + if parallel_dims.pp_enabled: + raise NotImplementedError( + "Pipeline parallelism is not yet supported in model_forward_step. " + "This will be implemented in a follow-up PR." + ) + else: + # Non-PP forward + assert len(model_parts) == 1 + with self.trainer.train_context(): + with self.trainer.maybe_enable_amp: + pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + + if isinstance(pred, DTensor): + pred = pred.full_tensor() + return pred + + def forward_step(self, micro_batch: TensorDict, loss_function, forward_only): + raise NotImplementedError("forward_step must be implemented in subclass") + + def optimizer_zero_grad(self): + """Zero gradients.""" + self.optimizer.zero_grad() + + def optimizer_step(self): + """Perform optimizer step with gradient clipping.""" + grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.module for p in m.parameters()], + self.config.training.max_norm, + foreach=True, + pp_mesh=self.parallel_dims.get_optional_mesh("pp"), + ep_enabled=self.parallel_dims.ep_enabled, + ) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + logger.warning(f"grad_norm is not finite: {grad_norm}") + self.optimizer.zero_grad() + else: + self.optimizer.step() + return grad_norm.item() + + def lr_scheduler_step(self): + """Advance learning rate scheduler.""" + self.lr_scheduler.step() + lr = self.lr_scheduler.schedulers[0].get_last_lr()[0] + return lr + + def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True): + """Move model and/or optimizer to CPU or GPU.""" + super().to(device=device, model=model, optimizer=optimizer, grad=grad) + + if self.engine_config.forward_only: + return + + device_name = get_device_name() + assert device in (device_name, "cpu") + if device == device_name: + if model: + for module in self.module: + load_fsdp_model_to_gpu(module) + if optimizer and self.optimizer is not None: + load_fsdp_optimizer(self.optimizer, device) + gc.collect() + elif device == "cpu": + if model: + for module in self.module: + offload_fsdp_model_to_cpu(module) + if optimizer and self.optimizer is not None: + offload_fsdp_optimizer(self.optimizer) + else: + raise ValueError(f"Invalid device type: {device}") + + def save_checkpoint( + self, + local_path: str, + hdfs_path: Optional[str] = None, + global_step: int = 0, + max_ckpt_to_keep: Optional[int] = None, + **kwargs, + ) -> None: + """Save checkpoint.""" + if self._is_offload_param: + for module in self.module: + load_fsdp_model_to_gpu(module) + + # Override TorchTitan's folder to use verl's path + parent_dir = os.path.dirname(local_path) + self.checkpointer.folder = parent_dir + + if max_ckpt_to_keep is not None: + self.checkpointer.keep_latest_k = max_ckpt_to_keep + + self.checkpointer.save(curr_step=global_step) + + torch.distributed.barrier() + if self._is_offload_param: + for module in self.module: + offload_fsdp_model_to_cpu(module) + + def load_checkpoint( + self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: int = True, **kwargs + ) -> None: + """Load checkpoint.""" + if self._is_offload_param: + for module in self.module: + load_fsdp_model_to_gpu(module) + + # Override TorchTitan's folder to use verl's path + parent_dir = os.path.dirname(local_path) + self.checkpointer.folder = parent_dir + + # Extract step number from path (verl uses global_step_N format) + match = re.search(r"global_step_(\d+)", local_path) + if match: + step = int(match.group(1)) + self.checkpointer.load(step=step) + else: + # Fallback to latest + self.checkpointer.load(step=-1) + + torch.distributed.barrier() + if self._is_offload_param: + for module in self.module: + offload_fsdp_model_to_cpu(module) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.optimizer) + + def get_per_tensor_param(self, **kwargs): + for module in self.module: + load_fsdp_model_to_gpu(module) + + # Collect state dicts from all model parts + params = {} + for module in self.module: + module_params = get_model_state_dict(module) + params.update(module_params) + + if self._is_offload_param: + for module in self.module: + offload_fsdp_model_to_cpu(module) + + # Convert TorchTitan key names to HuggingFace key names (expected by vLLM) + sd_adapter = self.checkpointer.sd_adapter + if sd_adapter is not None: + params = sd_adapter.to_hf(params) + + # When weight tying is enabled, the sd_adapter skips lm_head.weight during + # to_hf() conversion (since it's the same tensor as embed_tokens.weight in + # the torchtitan model). But vLLM needs lm_head.weight explicitly, so we + # add it back as a reference to embed_tokens.weight. + if "model.embed_tokens.weight" in params and "lm_head.weight" not in params: + params["lm_head.weight"] = params["model.embed_tokens.weight"] + + device = get_device_id() # used when fsdp2 set cpu_offload_policy + # TODO: cast fp32 to bf16 to reduce weight sync overhead, need more fine-grained control, e.g MoE gate + per_tensor_param = ( + ( + name, + param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True) + if isinstance(param, DTensor) + else param, + ) + for name, param in params.items() + ) + # TODO: support Torchtitan PEFT + return per_tensor_param, None + + +class EngineEvalModeCtx(BaseEngineCtx): + def __init__(self, engine: TorchTitanEngine, **kwargs): + super().__init__(engine=engine, mode="eval", **kwargs) + + def __enter__(self): + assert isinstance(self.engine, TorchTitanEngine) + super().__enter__() + for module in self.engine.module: + module.eval() + + def __exit__(self, exc_type, exc_value, traceback): + assert isinstance(self.engine, TorchTitanEngine) + + # Reshard the root FSDP module + if self.engine.engine_config.data_parallel_shard_size > 1: + for module in self.engine.module: + module.reshard() + + super().__exit__(exc_type, exc_value, traceback) + + +class EngineTrainModeCtx(BaseEngineCtx): + def __init__(self, engine: TorchTitanEngine, **kwargs): + super().__init__(engine=engine, mode="train", **kwargs) + + def __enter__(self): + assert isinstance(self.engine, TorchTitanEngine) + super().__enter__() + for module in self.engine.module: + module.train() + + def __exit__(self, exc_type, exc_value, traceback): + assert isinstance(self.engine, TorchTitanEngine) + self.engine.optimizer_zero_grad() + super().__exit__(exc_type, exc_value, traceback) + + +@EngineRegistry.register(model_type="language_model", backend=["torchtitan"], device=["cuda", "npu"]) +class TorchTitanEngineWithLMHead(TorchTitanEngine): + """TorchTitan engine implementation for language models with LM head.""" + + def prepare_model_inputs(self, micro_batch: TensorDict): + use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True) + pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) + assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported" + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch.get("multi_modal_inputs", [])) + input_ids = micro_batch["input_ids"] + position_ids = micro_batch["position_ids"] + output_args = {} + + if use_remove_padding: + input_ids = input_ids.values().unsqueeze(0) + if position_ids.dim() == 3: + position_ids = position_ids.values().unsqueeze(1) + else: + position_ids = position_ids.values().unsqueeze(0) + + labels = torch.roll(input_ids, shifts=-1, dims=1) + attn_type = self.trainer.model_args.attn_type + attention_mask = get_attention_masks( + input_batch=input_ids, + positions=position_ids, + attn_type=attn_type, + ) + else: + loss_mask = micro_batch["loss_mask"] + pad_token_id = tu.get_non_tensor_data(data=micro_batch, key="pad_token_id", default=0) + batch_size = micro_batch.batch_size[0] + max_seq_len = max(input_ids.offsets().diff()) + + labels = torch.roll(input_ids.values(), shifts=-1, dims=0) + input_ids = torch.nested.to_padded_tensor( + input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len) + ) + + if position_ids.dim() == 3: + position_ids = torch.nested.to_padded_tensor( + position_ids, padding=0, output_size=(batch_size, 4, max_seq_len) + ).transpose(0, 1) + else: + position_ids = torch.nested.to_padded_tensor( + position_ids, padding=0, output_size=(batch_size, max_seq_len) + ) + + attention_mask_list = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask] + attention_mask = torch.nested.as_nested_tensor(attention_mask_list, layout=torch.jagged) + attention_mask = torch.nested.to_padded_tensor( + attention_mask, padding=0, output_size=(batch_size, max_seq_len) + ) + + extra_inputs = { + "positions": position_ids, + } + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. + extra_kwargs: dict[str, Any] = {"attention_masks": attention_mask} + if self.parallel_dims.cp_enabled: + input_ids, labels, extra_kwargs = prepare_context_parallel_input( + input_ids, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + self.trainer.device, + self.trainer.job_config.parallelism.context_parallel_load_balancer, + ) + + # TODO(jessicazhong): multimodal is not yet supported for Torchtitan engine + extra_inputs.update(multi_modal_inputs) + output_args["labels"] = labels + return input_ids, extra_inputs, extra_kwargs, output_args + + def prepare_model_outputs(self, logits, output_args, micro_batch: TensorDict): + use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True) + pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) + assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported" + + temperature = micro_batch["temperature"] + calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False) + labels = output_args["labels"] + model_output = {} + + input_ids = micro_batch["input_ids"] + cu_seqlens = input_ids.offsets() + if use_remove_padding: + labels = labels.squeeze(0) + logits_rmpad = logits.squeeze(0) + # PyTorch's autograd doesn't allow in-place modification of views when gradients need to flow back + logits_rmpad = logits_rmpad / temperature + + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=labels, + inplace_backward=inplace_backward, + ) + + if calculate_entropy: + if not self.engine_config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint(self.compute_entropy_from_logits, logits_rmpad) + + log_probs = torch.nested.nested_tensor_from_jagged(log_probs.squeeze(0), cu_seqlens) + if calculate_entropy: + entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens) + else: + logits.div_(temperature) + if calculate_entropy: + if not self.engine_config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) + else: + entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) + + seq_lengths = cu_seqlens.diff() + starts = torch.zeros_like(seq_lengths, dtype=torch.int64) + logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged) + logits_rmpad = torch.cat([t for t in logits.unbind()]) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=output_args["labels"]) + log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens) + if calculate_entropy: + entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged) + entropy_rmpad = torch.cat([t for t in entropy.unbind()]) + entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens) + + model_output["log_probs"] = log_probs + if calculate_entropy: + model_output["entropy"] = entropy + + return model_output + + def forward_step(self, micro_batch: TensorDict, loss_function, forward_only): + device_name = get_device_name() + micro_batch = micro_batch.to(get_device_id()) + input_ids, extra_inputs, extra_kwargs, output_args = self.prepare_model_inputs(micro_batch=micro_batch) + + with torch.autocast(device_type=device_name, dtype=torch.bfloat16): + logits = self.model_forward_step(inputs=input_ids, extra_inputs=extra_inputs, extra_kwargs=extra_kwargs) + + model_output = self.prepare_model_outputs(logits=logits, output_args=output_args, micro_batch=micro_batch) + + if loss_function is not None: + loss, metrics = loss_function( + model_output=model_output, data=micro_batch, dp_group=self.get_data_parallel_group() + ) + else: + assert forward_only, "forward_only must be True when loss_function is None" + loss = torch.tensor(1.0, device=device_name) + metrics = {} + + output = { + "model_output": model_output, + "loss": loss.detach().item(), + "metrics": metrics, + } + + return loss, output diff --git a/verl/workers/engine/torchtitan/utils.py b/verl/workers/engine/torchtitan/utils.py new file mode 100644 index 00000000000..686fb94e6b2 --- /dev/null +++ b/verl/workers/engine/torchtitan/utils.py @@ -0,0 +1,213 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import torch +import torch.nn as nn +from torch.distributed._composable.fsdp import FSDPModule +from torch.nn.attention.flex_attention import _mask_mod_signature, and_masks +from torchtitan.models.attention import VarlenMetadata, create_attention_mask, get_causal_mask_mod +from torchtitan.protocols.model import AttentionMasksType + +logger = logging.getLogger(__name__) + +# Mapping from HuggingFace model_type to torchtitan model name. +# Torchtitan models not mapped here: +# - flux: diffusion model, not applicable to verl's RL/SFT workflows +# - llama3_ft: fault-tolerant variant of llama3, same HF models (mapped via "llama") +_HF_MODEL_TYPE_TO_TORCHTITAN_NAME = { + "qwen2": "qwen3", + "qwen3": "qwen3", + "qwen2_moe": "qwen3", + "qwen3_moe": "qwen3", + "llama": "llama3", + "llama4": "llama4", + "deepseek_v3": "deepseek_v3", + "gpt_oss": "gpt_oss", +} + + +def derive_torchtitan_name_and_flavor(hf_config) -> tuple[str, str]: + """Derive torchtitan model name and flavor from a HuggingFace config. + + The name is mapped from ``hf_config.model_type``. The flavor is found by + matching architecture parameters (dim, n_layers, vocab_size) against the + known flavors registered in the torchtitan TrainSpec. + + Args: + hf_config: A HuggingFace AutoConfig object. + + Returns: + A ``(name, flavor)`` tuple. + + Raises: + ValueError: If model_type is unsupported or no matching flavor is found. + """ + import torchtitan.protocols.train_spec as train_spec_module + + model_type = getattr(hf_config, "model_type", None) + if model_type is None: + raise ValueError("HuggingFace config does not have 'model_type' field") + + name = _HF_MODEL_TYPE_TO_TORCHTITAN_NAME.get(model_type) + if name is None: + raise ValueError( + f"Cannot derive torchtitan model name from HF model_type '{model_type}'. " + f"Supported types: {list(_HF_MODEL_TYPE_TO_TORCHTITAN_NAME.keys())}." + ) + + train_spec = train_spec_module.get_train_spec(name) + + hidden_size = hf_config.hidden_size + num_layers = hf_config.num_hidden_layers + vocab_size = hf_config.vocab_size + + for flavor_name, model_args in train_spec.model_args.items(): + if ( + getattr(model_args, "dim", None) == hidden_size + and getattr(model_args, "n_layers", None) == num_layers + and getattr(model_args, "vocab_size", None) == vocab_size + ): + logger.info( + f"Auto-derived torchtitan name='{name}', flavor='{flavor_name}' from HF model_type='{model_type}'" + ) + return name, flavor_name + + raise ValueError( + f"No matching torchtitan flavor found for model_type='{model_type}' " + f"(hidden_size={hidden_size}, num_hidden_layers={num_layers}, " + f"vocab_size={vocab_size}). " + f"Available flavors for '{name}': {list(train_spec.model_args.keys())}." + ) + + +def enable_fsdp_gradient_division(model: nn.Module, dp_size: int) -> None: + """ + Re-enable FSDP's automatic gradient division. + + TorchTitan calls disable_fsdp_gradient_division() which sets gradient_divide_factor=1.0. + This re-enables it by setting the factor to the specified dp_size, so gradients are + averaged across FSDP ranks. This is needed for verl's loss scaling (loss * dp_size) + to work correctly. + + Args: + model: The model (or model part) to enable gradient division on. + dp_size: The data parallel size to use as the gradient divide factor. + """ + + for module in model.modules(): + if isinstance(module, FSDPModule): + module.set_gradient_divide_factor(float(dp_size)) + + +def get_attention_masks( + input_batch: torch.Tensor, + positions: torch.Tensor, + attn_type: str, +) -> AttentionMasksType: + match attn_type: + case "flex": + return _get_flex_attention_masks( + input_batch, + positions, + ) + case "varlen": + return _create_varlen_metadata_for_document( + input_batch, + positions, + ) + case _: + raise TypeError("Only varlen and flex attn masks are supported") + + +def _get_document_mask_mod(positions: torch.Tensor) -> _mask_mod_signature: + # Detect boundaries from position resets + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq] + + def document_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + +def _get_flex_attention_masks( + input_batch: torch.Tensor, + positions: torch.Tensor, +) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + B = input_batch.shape[0] + mask_mods.append(_get_document_mask_mod(positions=positions)) + return create_attention_mask(and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]) + + +def _create_varlen_metadata_for_document(input_batch: torch.Tensor, positions: torch.Tensor) -> VarlenMetadata: + """ + Creates cumulative sequence length indices needed for variable length attention + + Args: + input_batch: Input token IDs with shape [batch, seq]. + positions: Position IDs with shape [batch, seq]. Boundaries detected where + position diff != 1 (i.e., position resets). + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size, seq_len = input_batch.shape + device = input_batch.device + + # Detect boundaries from position resets (where diff != 1) + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + # boundary_mask[b, i] is True if position i starts a new document + boundary_mask = position_diff != 1 # [batch, seq] + boundary_mask[:, 0] = True + + cu_seqlens_list, all_seq_lengths = [], [] + offset = 0 + + for b in range(batch_size): + # Find positions where new documents start + boundary_positions = boundary_mask[b].nonzero(as_tuple=True)[0].to(torch.int32) + sample_cu_seqlens = torch.cat( + [ + boundary_positions, + torch.tensor([seq_len], dtype=torch.int32, device=device), + ] + ) + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) + + seq_lengths = torch.diff(sample_cu_seqlens) + all_seq_lengths.append(seq_lengths) + + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat(cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward + max_seqlen = all_seq_lengths.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) diff --git a/verl/workers/engine/utils.py b/verl/workers/engine/utils.py index 0484b0d2a0c..b3261dafde3 100644 --- a/verl/workers/engine/utils.py +++ b/verl/workers/engine/utils.py @@ -70,7 +70,10 @@ def prepare_micro_batches( use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True) sp_size = tu.get_non_tensor_data(data=data, key="sp_size", default=1) + force_group_size = tu.get_non_tensor_data(data=data, key="force_group_size", default=1) + if use_dynamic_bsz: + assert force_group_size == 1, "force_group_size is not supported when use_dynamic_bsz is True" assert "max_token_len_per_gpu" in data.keys(), "max_token_len_per_gpu must be set when use_dynamic_bsz is True" max_token_len_per_gpu = data["max_token_len_per_gpu"] max_token_len = max_token_len_per_gpu * sp_size @@ -84,8 +87,12 @@ def prepare_micro_batches( use_dynamic_bsz_balance=use_dynamic_bsz_balance, ) else: + total_data_size = len(data) micro_batch_size_per_gpu = data["micro_batch_size_per_gpu"] - micro_batches = tu.chunk_tensordict(data, len(data) // micro_batch_size_per_gpu) + assert total_data_size % (force_group_size * micro_batch_size_per_gpu) == 0, ( + "data size must be divisible by force_group_size * micro_batch_size_per_gpu" + ) + micro_batches = tu.chunk_tensordict(data, total_data_size // (micro_batch_size_per_gpu * force_group_size)) batch_idx_list = None return micro_batches, batch_idx_list diff --git a/verl/workers/engine/veomni/transformer_impl.py b/verl/workers/engine/veomni/transformer_impl.py index 5aaf41f9669..6d0e50806dc 100644 --- a/verl/workers/engine/veomni/transformer_impl.py +++ b/verl/workers/engine/veomni/transformer_impl.py @@ -35,8 +35,11 @@ from verl.utils.fsdp_utils import fsdp_version from verl.utils.model import convert_weight_keys from verl.utils.profiler import log_gpu_memory_usage +from verl.utils.ulysses import ( + get_ulysses_sequence_parallel_group, + set_ulysses_sequence_parallel_group, +) from verl.workers.config import HFModelConfig, VeOmniEngineConfig, VeOmniOptimizerConfig -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from ..base import BaseEngineCtx, EngineRegistry from ..fsdp.transformer_impl import FSDPEngine, FSDPEngineWithLMHead @@ -79,14 +82,27 @@ def __init__( self.data_parallel_mode = "fsdp2" self.rank = dist.get_rank() + fsdp_size = self.engine_config.fsdp_size + world_size = dist.get_world_size() + dp_size = world_size // self.engine_config.ulysses_parallel_size + + if fsdp_size < 0 or fsdp_size >= dp_size: + data_parallel_replicate_size = 1 + data_parallel_shard_size = dp_size + else: + if dp_size % fsdp_size != 0: + raise ValueError( + f"Data parallel size ({dp_size}) must be divisible by fsdp_size ({fsdp_size}). " + "Please adjust your parallel configuration." + ) + data_parallel_replicate_size = dp_size // fsdp_size + data_parallel_shard_size = fsdp_size + parallel_state.init_parallel_state( - dp_size=self.engine_config.data_parallel_size, - dp_replicate_size=self.engine_config.data_parallel_replicate_size, - dp_shard_size=self.engine_config.data_parallel_shard_size, - tp_size=self.engine_config.tensor_parallel_size, + dp_size=dp_size, + dp_replicate_size=data_parallel_replicate_size, + dp_shard_size=data_parallel_shard_size, ep_size=self.engine_config.expert_parallel_size, - pp_size=self.engine_config.pipeline_parallel_size, - cp_size=self.engine_config.context_parallel_size, ulysses_size=self.engine_config.ulysses_parallel_size, dp_mode=self.data_parallel_mode, ) @@ -104,9 +120,9 @@ def __init__( self.ulysses_sequence_parallel_size = self.engine_config.ulysses_parallel_size if self.use_ulysses_sp: - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(parallel_state.get_parallel_state().device_mesh) + self.ulysses_parallel_group = parallel_state.get_parallel_state().device_mesh["sp"].get_group() else: - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(None) + self.ulysses_parallel_group = None if self.engine_config.entropy_from_logits_with_chunking: entropy_from_logits = verl_F.entropy_from_logits_with_chunking @@ -438,12 +454,13 @@ def __init__(self, engine: VeOmniEngine, **kwargs): def __enter__(self): assert isinstance(self.engine, VeOmniEngine) super().__enter__() - self.engine.ulysses_sharding_manager.__enter__() + self.prev_sp_group = get_ulysses_sequence_parallel_group() + set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group) self.engine.module.train() def __exit__(self, exc_type, exc_value, traceback): assert isinstance(self.engine, VeOmniEngine) - self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback) + set_ulysses_sequence_parallel_group(self.prev_sp_group) # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -463,14 +480,15 @@ def __init__(self, engine: VeOmniEngine, **kwargs): def __enter__(self): assert isinstance(self.engine, VeOmniEngine) super().__enter__() - self.engine.ulysses_sharding_manager.__enter__() + self.prev_sp_group = get_ulysses_sequence_parallel_group() + set_ulysses_sequence_parallel_group(self.engine.ulysses_parallel_group) # TODO: Switch to eval mode after Integrating the CI environment # VeOmni (ref: https://github.com/ByteDance-Seed/VeOmni/pull/421) self.engine.module.train() def __exit__(self, exc_type, exc_value, traceback): assert isinstance(self.engine, VeOmniEngine) - self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback) + set_ulysses_sequence_parallel_group(self.prev_sp_group) self.engine.optimizer_zero_grad() super().__exit__(exc_type, exc_value, traceback) @@ -492,6 +510,19 @@ class OmniSequenceShardCollator: metadata={"help": "features to slice sequence dimension."}, ) + # features to padding sequence dimension + padding_features: dict[str, int] = field( + default_factory=lambda: { + "pixel_values": 0, + }, + metadata={"help": "features to padding sequence dimension."}, + ) + + # padding scale for padding features + padding_scale: dict[str, int] = field( + default_factory=lambda: {"pixel_values": 4}, metadata={"help": "padding scale for padding features."} + ) + def __post_init__(self): self.sp_size = parallel_state.get_parallel_state().sp_size self.sp_rank = parallel_state.get_parallel_state().sp_rank @@ -501,7 +532,35 @@ def sp_slice(self, feature: torch.Tensor, dim: int = -1) -> dict[str, "torch.Ten sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size return feature.narrow(dim, self.sp_rank * sp_chunk_size, sp_chunk_size) + def sp_padding( + self, tensor: "torch.Tensor", dim: int = -1, pad_value: int = 0, pad_scale: int = 1 + ) -> "torch.Tensor": + """ + Pads a tensor with pad_length to aligns tensor with sp size. + """ + seq_length = tensor.size(dim) + scale_sp_size = self.sp_size * pad_scale + + sp_chunk_size = (seq_length + scale_sp_size - 1) // scale_sp_size + pad_size = sp_chunk_size * scale_sp_size - seq_length + if pad_size == 0: + return tensor + + pad_shape = list(tensor.shape) + pad_shape[dim] = pad_size + pad = torch.full(pad_shape, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat((tensor, pad), dim=dim) + def __call__(self, batch: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]: + for key in batch.keys(): + if key in self.padding_features.keys(): + batch[key] = self.sp_padding( + batch[key], + dim=self.sp_slice_features.get(key, -1), + pad_value=self.padding_features[key], + pad_scale=self.padding_scale.get(key, 1), + ) + # sp slice for key in batch.keys(): if key in self.sp_slice_features.keys(): diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 6f8029600ea..3d129d1fe2f 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -179,17 +179,20 @@ def _postprocess_output(self, output, *, global_token_num, delta_time, forward_o # Here each metric in metrics can be a list (micro-batch metrics) or a singleton # we should always sum the loss of each micro-batch as we scale by global_bsz/global_token loss = torch.sum(torch.tensor(output.pop("loss"), device=self.device_name)) - torch.distributed.all_reduce( - loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group() - ) + dp_group = self.engine.get_data_parallel_group() + if dp_group is not None: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_group) loss = loss.item() # For grad_norm, we do not perform all reduce because it is already been done when clipping grad grad_norm = metrics.pop("grad_norm", None) lr = metrics.pop("lr", None) - # For other metrics, we perform all gather in dp group - final_metrics = allgather_dict_into_dict(data=metrics, group=self.engine.get_data_parallel_group()) + # For other metrics, we perform all gather in dp group (only if DP > 1) + if dp_group is not None: + final_metrics = allgather_dict_into_dict(data=metrics, group=dp_group) + else: + final_metrics = metrics final_metrics["loss"] = loss if grad_norm is not None: final_metrics["grad_norm"] = grad_norm @@ -580,6 +583,9 @@ def init_model(self): backend, is_master=(torch.distributed.get_rank() == 0), bucket_size=bucket_size, **engine_kwargs ) + # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo + aggressive_empty_cache(force_sync=True) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") @_with_routing_replay_flag(enabled=False) @@ -621,11 +627,10 @@ async def update_weights(self): - after update_weights: rollout should be in wake_up mode. 2. For async training with disaggregated trainer and rollout, send_weights only by checkpoint engine. """ - assert self.checkpoint_engine is not None # 0. send_weights only for async training with disaggregated trainer and rollout if self.config.rollout.checkpoint_engine.backend != "naive": - per_tensor_param, _ = self.engine.get_per_tensor_param() + per_tensor_param, _ = self.actor.engine.get_per_tensor_param() await self.checkpoint_engine.send_weights(per_tensor_param) return diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 1425091432e..24b46e13c1b 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -28,6 +28,7 @@ import torch.distributed as dist from codetiming import Timer from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf.errors import ConfigAttributeError from peft import LoraConfig, TaskType, get_peft_model from safetensors.torch import save_file from torch.distributed.device_mesh import init_device_mesh @@ -81,6 +82,9 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max from verl.utils.py_functional import convert_to_regular_types + +# QAT support +from verl.utils.qat import apply_qat, enable_qat_fuse from verl.utils.ray_utils import get_event_loop from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig from verl.workers.config.optimizer import build_optimizer @@ -275,6 +279,52 @@ def __init__(self, config: DictConfig, role: str, **kwargs): self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + def _init_qat_config(self): + """Initialize QAT configuration from actor.qat.""" + try: + self.qat_config = self.config.actor.qat + self._qat_enabled = self.qat_config.enable + if self._qat_enabled: + logger.info( + f"QAT enabled: mode={self.qat_config.mode}, config_path={self.qat_config.quantization_config_path}" + ) + except (AttributeError, KeyError, ConfigAttributeError): + # QAT config not provided, disable QAT + self._qat_enabled = False + self.qat_config = None + + def _restore_w4a4_input_scales(self, model, model_path): + """Restore input_global_scale and input_amax from checkpoint for W4A4 mode.""" + import glob + + from safetensors import safe_open + + safetensor_files = glob.glob(f"{model_path}/model*.safetensors") + loaded_count = 0 + + for sf_path in safetensor_files: + with safe_open(sf_path, framework="pt") as f: + for key in f.keys(): + if "input_global_scale" in key: + module_path = key.replace(".input_global_scale", "") + amax_key = f"{module_path}.input_amax" + + module = model + for part in module_path.split("."): + module = getattr(module, part) + + scale_val = f.get_tensor(key) + val = scale_val.item() if scale_val.numel() == 1 else scale_val.max().item() + module.input_global_scale.fill_(val) + + amax_val = f.get_tensor(amax_key) + amax = amax_val.item() if amax_val.numel() == 1 else amax_val.max().item() + module.input_amax.fill_(amax) + loaded_count += 1 + + if self.rank == 0: + logger.info(f"[W4A4] Loaded {loaded_count} input scales from checkpoint") + def _build_model_optimizer( self, model_path, @@ -485,6 +535,13 @@ def _build_model_optimizer( if self.rank == 0: print("[actor model] No vision tower found.") + # Apply QAT before FSDP wrapping (actor only) + if role == "actor" and self._qat_enabled: + actor_module = apply_qat(actor_module, self.qat_config) + enable_qat_fuse(actor_module) + if self.qat_config.mode == "w4a4": + self._restore_w4a4_input_scales(actor_module, self.config.model.path) + torch.distributed.barrier() if self.rank == 0: @@ -505,6 +562,9 @@ def _build_model_optimizer( mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + # Store param_dtype for QAT quantizer + self._param_dtype = param_dtype + auto_wrap_policy = get_fsdp_wrap_policy( module=actor_module, config=fsdp_config.get("wrap_policy", None), @@ -740,6 +800,23 @@ async def rollout_mode(self): for name, param in params.items() ) + # QAT: quantize weights before sending to vLLM + if self._qat_enabled: + from verl.utils.qat.quantizer import QATQuantizer + + quantizer = QATQuantizer( + mode=self.qat_config.mode, + group_size=self.qat_config.group_size, + ignore_patterns=self.qat_config.ignore_patterns, + device=torch.device(get_device_id()), + param_dtype=self._param_dtype, + ) + per_tensor_param = quantizer.quantize_with_fusion( + per_tensor_param, + target_device=torch.device("cpu"), + ) + aggressive_empty_cache(force_sync=True) + if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) log_gpu_memory_usage("After resume weights", logger=logger) @@ -774,6 +851,9 @@ def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) + # Initialize QAT config before _build_model_optimizer + self._init_qat_config() + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) use_remove_padding = self.config.model.get("use_remove_padding", False) use_shm = self.config.model.get("use_shm", False) @@ -901,6 +981,9 @@ def init_model(self): checkpoint_config=checkpoint_contents, ) + # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo + aggressive_empty_cache(force_sync=True) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: DataProto): diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 3baf4020810..f78dcde56e2 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -155,7 +155,6 @@ def _init_hf_config_and_tf_config( if enable_mtp: assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer" assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True" - assert megatron_config.vanilla_mbridge, "MTP requires vanilla_mbridge to be True" override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor else: if hasattr(hf_config, "num_nextn_predict_layers"): @@ -199,6 +198,10 @@ def _init_hf_config_and_tf_config( # In case of invalid overrides, we need to make sure some critical params are set correctly provider.params_dtype = dtype + # Ensure dtype settings propagate to Megatron-Bridge/TE + provider.fp16 = fp16 + provider.bf16 = bf16 + # Pass distributed info provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size @@ -670,7 +673,8 @@ def init_model(self): if not self.config.actor.megatron.use_mbridge: self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - get_torch_device().empty_cache() + # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo + aggressive_empty_cache(force_sync=True) log_gpu_memory_usage("After init_model finish", logger=logger) async def rollout_mode(self): diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py index 31d5b9736b7..c8038606f1f 100644 --- a/verl/workers/rollout/base.py +++ b/verl/workers/rollout/base.py @@ -34,6 +34,8 @@ def __init__( config: RolloutConfig, model_config: HFModelConfig, device_mesh: DeviceMesh, + *args, + **kwargs, ): self.config = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index bf83ac7d05f..ed327c25293 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -18,6 +18,7 @@ from enum import Enum from typing import Any, Callable, Optional +import ray from omegaconf import DictConfig from pydantic import BaseModel from ray.actor import ActorHandle @@ -30,6 +31,11 @@ logger = logging.getLogger(__file__) +# Max number of concurrent calls to the methods of Rollout, +# excluding calls to generate method. +CONTROL_METHOD_CONCURRENCY = 16 + + class TokenOutput(BaseModel): token_ids: list[int] """response token ids""" @@ -91,7 +97,7 @@ def __init__( is_reward_model: bool = False, ) -> None: self.replica_rank = replica_rank - self.config = omega_conf_to_dataclass(config) + self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = model_config self.world_size = ( @@ -200,10 +206,18 @@ async def init_standalone(self): self.workers = worker_group.workers await self.launch_servers() - @abstractmethod def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: """Get rollout worker actor class for colocated and standalone mode.""" - raise NotImplementedError + from verl.checkpoint_engine.base import CheckpointEngineWorker + + rollout_worker_actor_cls = ray.remote(CheckpointEngineWorker) + + return RayClassWithInitArgs( + cls=rollout_worker_actor_cls, + rollout_config=self.config, + model_config=self.model_config, + replica_rank=self.replica_rank, + ) @abstractmethod async def launch_servers(self): @@ -220,6 +234,12 @@ def server_handle(self) -> ActorHandle: """Get rollout server handle for Token-in-token-out generation.""" return self._server_handle + @property + def max_concurrency(self) -> int: + # 1000 is Ray's default max_concurrency for async execution. + # Add some margin to account for control method call. + return max(1000, self.config.max_num_seqs + CONTROL_METHOD_CONCURRENCY) + def rollout_worker_use_gpu(self) -> bool: return True diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index ab8ee461dea..964fe97df1f 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -39,15 +39,14 @@ ) from sglang.srt.managers.tokenizer_manager import ServerStatus -from verl.single_controller.ray import RayClassWithInitArgs from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import get_visible_devices_keyword from verl.utils.net_utils import get_free_port, is_valid_ipv6_address from verl.utils.profiler import DistProfiler, build_sglang_profiler_args from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput -from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter, _set_envs_and_config -from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn +from verl.workers.rollout.sglang_rollout.sglang_rollout import _set_envs_and_config +from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) @@ -123,17 +122,19 @@ def __init__( profiler_config = None self.profiler_controller = DistProfiler(self.replica_rank, config=profiler_config, tool_config=tool_config) - # used for NCCL process group - if self.node_rank == 0: + # For multi-node, we need dist_init_addr so nodes can coordinate NCCL init. + # For single-node, let SGLang handle port selection internally via nccl_port, + # which also avoids port conflicts. + self._master_address = None + self._master_port = None + self._master_sock = None + if self.nnodes > 1 and self.node_rank == 0: self._master_address = self._server_address - self._master_port, self._master_sock = get_free_port(self._server_address) + self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True) logger.info( f"SGLangHttpServer, replica_rank: {self.replica_rank}, " f"master address: {self._master_address}, port: {self._master_port}" ) - else: - self._master_address = None - self._master_port = None def get_master_address(self): """Get master address and port for init NCCL process group.""" @@ -145,10 +146,13 @@ def get_server_address(self): return self._server_address, self._server_port async def launch_server(self, master_address: str = None, master_port: int = None): - if self.node_rank != 0: - assert master_address and master_port, "non-master node should provide master address and port" - self._master_address = master_address - self._master_port = master_port + if self.nnodes > 1: + if self.node_rank != 0: + assert master_address and master_port, "non-master node should provide master address and port" + self._master_address = master_address + self._master_port = master_port + else: + self._master_sock.close() engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {} attention_backend = engine_kwargs.pop("attention_backend", None) @@ -167,11 +171,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) else: raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") - dist_init_addr = ( - f"[{self._master_address}]:{self._master_port}" - if is_valid_ipv6_address(self._master_address) - else f"{self._master_address}:{self._master_port}" - ) infer_tp = self.config.tensor_model_parallel_size * self.config.data_parallel_size args = { "model_path": self.model_config.local_path, @@ -186,7 +185,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "ep_size": self.config.expert_parallel_size, "node_rank": self.node_rank, "load_format": self.config.load_format, - "dist_init_addr": dist_init_addr, "nnodes": self.nnodes, "trust_remote_code": self.model_config.trust_remote_code, "max_running_requests": self.config.get("max_num_seqs", None), @@ -202,6 +200,16 @@ async def launch_server(self, master_address: str = None, master_port: int = Non **engine_kwargs, } + # Only set dist_init_addr for multi-node; for single-node, let SGLang + # handle port selection internally via nccl_port to avoid conflicts. + if self.nnodes > 1: + dist_init_addr = ( + f"[{self._master_address}]:{self._master_port}" + if is_valid_ipv6_address(self._master_address) + else f"{self._master_address}:{self._master_port}" + ) + args["dist_init_addr"] = dist_init_addr + if self.config.prometheus.enable: if self.config.prometheus.served_model_name: # Extract model name from path if it's a full path @@ -278,7 +286,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non add_prometheus_middleware(app) - self._server_port, self._server_task = await run_unvicorn(app, server_args, self._server_address) + self._server_port, self._server_task = await run_uvicorn(app, server_args, self._server_address) self.tokenizer_manager.server_status = ServerStatus.Up async def wake_up(self): @@ -413,9 +421,6 @@ async def stop_profile(self): await self.tokenizer_manager.stop_profile() -_rollout_worker_actor_cls = ray.remote(ServerAdapter) - - class SGLangReplica(RolloutReplica): def __init__( self, @@ -428,16 +433,6 @@ def __init__( super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model) self.server_class = ray.remote(SGLangHttpServer) - def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: - """Get rollout worker actor class for colocated and standalone mode.""" - worker_dict_cls = RayClassWithInitArgs( - cls=_rollout_worker_actor_cls, - config=self.config, - model_config=self.model_config, - device_mesh=None, - ) - return worker_dict_cls - async def launch_servers(self): """Launch http server in each node.""" assert len(self.workers) == self.world_size, ( @@ -496,6 +491,7 @@ async def launch_servers(self): ), runtime_env={"env_vars": {f"RAY_EXPERIMENTAL_NOSET_{visible_devices_keyword}": "1"}}, name=name, + max_concurrency=self.max_concurrency, ).remote( config=self.config, model_config=self.model_config, @@ -510,7 +506,9 @@ async def launch_servers(self): self.servers.append(server) # launch http server in each node - master_address, master_port = await self.servers[0].get_master_address.remote() + master_address, master_port = None, None + if self.nnodes > 1: + master_address, master_port = await self.servers[0].get_master_address.remote() await asyncio.gather( *[ server.launch_server.remote(master_address=master_address, master_port=master_port) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 2be15fc5b05..3048ab60148 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -98,6 +98,7 @@ def __init__( config: RolloutConfig, model_config: HFModelConfig, device_mesh: DeviceMesh, + replica_rank: int = -1, ): if config.get("quantization", None) == "fp8": import sglang @@ -120,7 +121,10 @@ def __init__( rank = int(os.environ["RANK"]) local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"]) rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size - self.replica_rank = rank // rollout_world_size + if replica_rank == -1: + self.replica_rank = rank // rollout_world_size + else: + self.replica_rank = replica_rank self.rollout_rank = rank % rollout_world_size self.node_rank = self.rollout_rank // local_world_size self.local_rank = self.rollout_rank % local_world_size diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 97e255ad68c..195e798a6cb 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -23,13 +23,13 @@ from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup -from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool +from verl.single_controller.ray import SubRayResourcePool from verl.utils.config import omega_conf_to_dataclass from verl.utils.net_utils import is_valid_ipv6_address from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter -from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn +from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) @@ -110,7 +110,7 @@ def get_server_address(self): async def launch_server(self): from tensorrt_llm import AsyncLLM - from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig + from tensorrt_llm.llmapi import CapacitySchedulerPolicy, CudaGraphConfig, KvCacheConfig, SchedulerConfig from tensorrt_llm.serve import OpenAIServer assert self.config.pipeline_model_parallel_size == 1, "pipeline_model_parallel_size > 1 is not supported yet" @@ -162,7 +162,10 @@ async def launch_server(self): enable_padding=True, batch_sizes=self.config.cudagraph_capture_sizes, max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs, - ) + ), + "scheduler_config": SchedulerConfig( + capacity_scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION, + ), } ) @@ -202,7 +205,7 @@ async def launch_server(self): ) app = trtllm_server.app - self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) + self._server_port, self._server_task = await run_uvicorn(app, None, self._server_address) @resume_on_abort async def generate( @@ -280,9 +283,6 @@ async def report_device_ids(self) -> list[str]: ) -_rollout_worker_actor_cls = ray.remote(ServerAdapter) - - class TRTLLMReplica(RolloutReplica): def __init__( self, @@ -295,17 +295,6 @@ def __init__( super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model) self.node_ip = ray.util.get_node_ip_address().strip("[]") - def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: - """Get rollout worker actor class for colocated and standalone mode.""" - worker_dict_cls = RayClassWithInitArgs( - cls=_rollout_worker_actor_cls, - config=self.config, - model_config=self.model_config, - device_mesh=None, - replica_rank=self.replica_rank, - ) - return worker_dict_cls - def rollout_worker_use_gpu(self) -> bool: return False @@ -392,6 +381,7 @@ async def launch_servers(self): ), runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}, name=name, + max_concurrency=self.max_concurrency, ).remote( config=self.config, model_config=self.model_config, diff --git a/verl/workers/rollout/utils.py b/verl/workers/rollout/utils.py index 246ed3896b1..69c688dfa24 100644 --- a/verl/workers/rollout/utils.py +++ b/verl/workers/rollout/utils.py @@ -13,13 +13,10 @@ # limitations under the License. import asyncio import logging -import os import uvicorn from fastapi import FastAPI -from verl.utils.net_utils import get_free_port - logger = logging.getLogger(__file__) @@ -35,25 +32,42 @@ def get_max_position_embeddings(hf_config) -> int: return int(max_len) -async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5) -> tuple[int, asyncio.Task]: - server_port, server_task = None, None +class _UvicornServerAutoPort(uvicorn.Server): + """Uvicorn Server that reports the system-assigned port when port=0.""" + + def __init__(self, config: uvicorn.Config) -> None: + super().__init__(config) + self.actual_port: int | None = None + self._startup_done: asyncio.Event = asyncio.Event() - for i in range(max_retries): + async def startup(self, sockets=None) -> None: try: - server_port, sock = get_free_port(server_address) - app.server_args = server_args - config = uvicorn.Config(app, host=server_address, port=server_port, log_level="warning") - server = uvicorn.Server(config) - server.should_exit = True - await server.serve() - server_task = asyncio.create_task(server.main_loop()) - break - except (OSError, SystemExit) as e: - logger.error(f"Failed to start HTTP server on port {server_port} at try {i}, error: {e}") - else: - logger.error(f"Failed to start HTTP server after {max_retries} retries, exiting...") - os._exit(-1) + await super().startup(sockets=sockets) + if self.servers and self.config.port == 0: + sock = self.servers[0].sockets[0] + self.actual_port = sock.getsockname()[1] + else: + self.actual_port = self.config.port + finally: + self._startup_done.set() + + async def get_port(self) -> int | None: + await self._startup_done.wait() + return self.actual_port + + +async def run_uvicorn(app: FastAPI, server_args, server_address) -> tuple[int, asyncio.Task]: + app.server_args = server_args + config = uvicorn.Config(app, host=server_address, port=0, log_level="warning") + server = _UvicornServerAutoPort(config) + server_task = asyncio.create_task(server.serve()) + server_port = await server.get_port() + if server_port is None: + # server.startup() failed. await the task to re-raise exception from server.serve() + await server_task + # Fails on unexpected situation. + raise RuntimeError("Unexpected: HTTP server started without reporting listened port") logger.info(f"HTTP server started on port {server_port}") return server_port, server_task diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index d2369dc2e89..7fa3b1dd67c 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -163,6 +163,15 @@ def __new__(cls, **kwargs): # 2. patch online fp8 quant if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1": apply_vllm_fp8_patches() + # 3. patch QAT (compressed-tensors NVFP4) for dynamic weight loading + vllm_config = kwargs.get("vllm_config") + quant_config = getattr(vllm_config, "quant_config", None) if vllm_config else None + _is_qat_model = getattr(quant_config, "quant_format", None) == "nvfp4-pack-quantized" + if _is_qat_model: + from verl.utils.qat import apply_qat_patches + + apply_qat_patches() + logger.info("Applied QAT patches in vLLM worker subprocess") # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. @@ -172,7 +181,9 @@ def __new__(cls, **kwargs): if k not in os.environ: os.environ[k] = VLLM_ASCEND_REQUIRED_ENV_VARS[k] - return super().__new__(cls) + instance = super().__new__(cls) + instance._is_qat_model = _is_qat_model + return instance def monkey_patch_model(self, vocab_size: int): # patch compute_logits to avoid sampling OOV token @@ -214,8 +225,14 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False self.model_runner.vllm_config ) - # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs. - if use_standard_weight_load: + if self._is_qat_model: + # QAT: Prepare for weight loading BEFORE receiving any buckets + from verl.utils.qat import prepare_qat_for_load_weights + + prepare_qat_for_load_weights(self.model_runner.model, device=self.device) + logger.info("QAT: prepare_qat_for_load_weights completed") + elif use_standard_weight_load: + # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs. patch_vllm_moe_model_weight_loader(self.model_runner.model) # receive bucket and update weights @@ -241,7 +258,13 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False if metadata["is_last"]: break - if use_standard_weight_load: + if self._is_qat_model: + # QAT: call process_weights_after_loading AFTER all buckets are received + from verl.utils.qat import manual_process_weights_after_loading + + manual_process_weights_after_loading(self.model_runner.model) + logger.info("QAT: process_weights_after_loading completed") + elif use_standard_weight_load: # Some post-load transforms are non-idempotent; run once after all buckets. from vllm.model_executor.model_loader.utils import process_weights_after_loading @@ -252,6 +275,7 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False # clean up socket.close() del buffer + gc.collect() if shm is not None: shm.close() del shm diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index cf5ab342888..9c2bbad47cf 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -35,7 +35,6 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.async_llm import AsyncLLM -from verl.single_controller.ray import RayClassWithInitArgs from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import get_resource_name, get_visible_devices_keyword from verl.utils.net_utils import get_free_port, is_valid_ipv6_address @@ -43,8 +42,7 @@ from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput -from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn -from verl.workers.rollout.vllm_rollout import ServerAdapter +from verl.workers.rollout.utils import get_max_position_embeddings, run_uvicorn from verl.workers.rollout.vllm_rollout.utils import ( VLLM_LORA_INT_ID, VLLM_LORA_NAME, @@ -153,10 +151,10 @@ def __init__( if self.node_rank == 0: self._master_address = self._server_address # used for torch.distributed.init_process_group - self._master_port, self._master_sock = get_free_port(self._server_address) + self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True) # used for data parallel: --data-parallel-address, --data-parallel-rpc-port - self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address) - self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address) + self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address, with_alive_sock=True) + self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address, with_alive_sock=True) else: self._master_address = None self._master_port = None @@ -182,6 +180,12 @@ def get_server_address(self): assert self._server_port is not None, "http server is not launched, port is None" return self._server_address, self._server_port + @property + def lora_as_adapter(self) -> bool: + return ( + self.model_config.lora_rank > 0 or self.model_config.lora.get("rank", 0) > 0 + ) and not self.model_config.lora.get("merge", False) + async def collective_rpc( self, method: str | Callable, @@ -231,8 +235,25 @@ async def launch_server(self, master_address: str = None, master_port: int = Non set_expandable_segments(True) quantization = self.config.quantization + hf_overrides = {} + + # Handle QAT (Quantization-Aware Training) configuration + qat_config_dict = getattr(self.config, "qat", {}) or {} + if qat_config_dict.get("enable", False): + # QAT uses compressed-tensors quantization, apply patches for dynamic weight loading + from verl.utils.qat import QATConfig, apply_qat_patches, load_quantization_config - if quantization is not None: + apply_qat_patches() + + # Load quantization config from JSON file + qat_config = QATConfig(**qat_config_dict) + quantization_config_dict = load_quantization_config(qat_config) + hf_overrides["quantization_config"] = quantization_config_dict + quantization = "compressed-tensors" + + logger.info("QAT quantization config injected to vLLM async server") + elif quantization is not None: + # Handle other quantization methods (fp8, torchao) _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] if quantization not in _SUPPORTED_QUANTIZATION: raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}") @@ -250,19 +271,16 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "weight_block_size": [128, 128], "ignored_layers": all_mlp_gate_layers, } - fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + hf_overrides["quantization_config"] = dict(FP8_BLOCK_QUANT_KWARGS) # Apply vllm fp8 patches # Will remove the patch after vllm support on-the-fly quant for rollout natively. apply_vllm_fp8_patches() # for subprocesses patching os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" - hf_overrides = {} if quantization is not None and self.config.quantization_config_file is not None: hf_overrides["quantization_config_file"] = self.config.quantization_config_file - if quantization == "fp8": - hf_overrides["quantization_config"] = fp8_block_quant_kwargs compilation_config = engine_kwargs.pop("compilation_config", None) or {} if isinstance(compilation_config, str): compilation_config = json.loads(compilation_config) @@ -410,6 +428,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non # 3. launch server if self.node_rank == 0: self._master_sock.close() + self._dp_rpc_sock.close() + self._dp_master_sock.close() await self.run_server(server_args) else: # TODO: avoid connect before master_sock close @@ -456,7 +476,7 @@ async def run_server(self, args: argparse.Namespace): logger.info(f"Initializing a V1 LLM engine with config: {vllm_config}") self.engine = engine_client - self._server_port, self._server_task = await run_unvicorn(app, args, self._server_address) + self._server_port, self._server_task = await run_uvicorn(app, args, self._server_address) async def run_headless(self, args: argparse.Namespace): """Run headless server in a separate thread.""" @@ -529,9 +549,7 @@ async def generate( # Add lora request lora_request = None - if ( - self.model_config.lora_rank > 0 or self.model_config.lora.get("rank", 0) > 0 - ) and not self.model_config.lora.get("merge", False): + if self.lora_as_adapter: # Make sure we also check that the lora is already loaded in the engine lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras() if lora_loaded: @@ -604,7 +622,12 @@ async def sleep(self): if self.rollout_mode == RolloutMode.HYBRID: # Don't use engine.sleep(level=2) here - await self.engine.collective_rpc("sleep", kwargs={"level": 2}) + # lora only update adapter weights, so set sleep level to 1 + if self.lora_as_adapter: + sleep_level = 1 + else: + sleep_level = 2 + await self.engine.collective_rpc("sleep", kwargs={"level": sleep_level}) # clear encoder cache: https://github.com/vllm-project/vllm/pull/33452 # await self.engine.reset_encoder_cache() @@ -751,9 +774,6 @@ async def abort_request(self, request_id: str, reset_prefix_cache: bool = True) return {"aborted": False, "request_id": request_id, "error": str(e)} -_rollout_worker_actor_cls = ray.remote(ServerAdapter) - - class vLLMReplica(RolloutReplica): def __init__( self, @@ -766,16 +786,6 @@ def __init__( super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model) self.server_class = ray.remote(vLLMHttpServer) - def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: - """Get rollout worker actor class for colocated and standalone mode.""" - worker_dict_cls = RayClassWithInitArgs( - cls=_rollout_worker_actor_cls, - config=self.config, - model_config=self.model_config, - device_mesh=None, - ) - return worker_dict_cls - async def launch_servers(self): """Launch http server in each node.""" assert len(self.workers) == self.world_size, ( @@ -822,8 +832,14 @@ async def launch_servers(self): node_id=node_id, soft=False, ), - runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}, + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + } + }, name=name, + max_concurrency=self.max_concurrency, ).remote( config=self.config, model_config=self.model_config, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 75efb81d892..53a433cc51e 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -72,6 +72,7 @@ def __init__( config: RolloutConfig, model_config: HFModelConfig, device_mesh: DeviceMesh, + replica_rank: int = -1, ): super().__init__(config, model_config, device_mesh) self.server_handle: ray.actor.ActorHandle = None @@ -83,7 +84,10 @@ def __init__( * self.config.data_parallel_size * self.config.pipeline_model_parallel_size ) - self.replica_rank = rank // rollout_world_size + if replica_rank == -1: + self.replica_rank = rank // rollout_world_size + else: + self.replica_rank = replica_rank self.rollout_rank = rank % rollout_world_size self.node_rank = self.rollout_rank // local_world_size @@ -169,7 +173,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None buffer, shm = None, None if not self.use_shm: - buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:0") + buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}") handle = reduce_tensor(buffer) s.send_pyobj(handle) else: @@ -228,6 +232,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # clean up s.close() del buffer + gc.collect() if shm is not None: shm.close() shm.unlink() diff --git a/verl/workers/utils/padding.py b/verl/workers/utils/padding.py index d4bbf88e8d8..16242e7731f 100644 --- a/verl/workers/utils/padding.py +++ b/verl/workers/utils/padding.py @@ -105,7 +105,7 @@ def no_padding_2_padding(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor prompt_lens = prompt_ids.offsets().diff() response_lens = response_ids.offsets().diff() if max_response_len < 0: - max_response_len = response_ids.offsets().diff().max().item() + max_response_len = response_lens.max().item() else: assert not attention_mask.is_nested prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)