From 3e42eaf7d5e84f5a1f341e42b84eb59b8170b253 Mon Sep 17 00:00:00 2001 From: Superjomn <328693+Superjomn@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:52:21 +0800 Subject: [PATCH 1/3] [docker] feat: upgrade Megatron-LM to core_v0.16.0 and TRT-LLM to 1.3.0rc10 - Upgrade Megatron-LM to core_v0.16.0, switch DeepEP branch from v1.2.1 to hybrid-ep (removing now-unnecessary patch), add CCCL CPATH for build compat - Bump TRT-LLM base image from 1.3.0rc4 to 1.3.0rc10 - Pin trl==0.27.0 to fix AutoModelForCausalLMWithValueHead import Co-authored-by: Claude Sonnet 4.6 --- docker/Dockerfile.stable.trtllm | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile.stable.trtllm b/docker/Dockerfile.stable.trtllm index 61c27b911c6..d5913b14842 100644 --- a/docker/Dockerfile.stable.trtllm +++ b/docker/Dockerfile.stable.trtllm @@ -1,9 +1,13 @@ # Base image from NGC TensorRT-LLM, which includes a pre-installed TensorRT-LLM. # For available images, visit: https://nvidia.github.io/TensorRT-LLM/installation/containers.html # Use TRTLLM_BASE_IMAGE to specify the base image (default: release:1.2.0rc6) -ARG TRTLLM_BASE_IMAGE=nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc4 +ARG TRTLLM_BASE_IMAGE=nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 FROM ${TRTLLM_BASE_IMAGE} +# Clear TORCH_CUDA_ARCH_LIST inherited from the base image so that +# FlashInfer's check_cuda_arch() queries the actual GPU at runtime +# instead of rejecting GPUs not in the build-time arch list. 
+ENV TORCH_CUDA_ARCH_LIST="" # ============================================================================== # Install Megatron dependencies @@ -22,18 +26,17 @@ RUN git clone -b v2.5.1 https://github.com/NVIDIA/gdrcopy.git && \ pushd ${NVSHMEM_DIR}/lib && \ ln -s libnvshmem_host.so.3 libnvshmem_host.so && \ popd && \ - git clone -b v1.2.1 https://github.com/deepseek-ai/DeepEP.git && \ + git clone -b hybrid-ep https://github.com/deepseek-ai/DeepEP.git && \ pushd DeepEP && \ - wget https://raw.githubusercontent.com/NVIDIA/Megatron-LM/refs/tags/core_v0.15.0/docker/patches/deepep.patch && \ - patch -p1 < deepep.patch && \ + export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:$CPATH && \ TORCH_CUDA_ARCH_LIST="9.0 10.0 12.0" python setup.py install && \ popd && rm -rf deepep # Install Python dependencies -RUN pip3 install --no-cache-dir --no-deps trl && \ +RUN pip3 install --no-cache-dir --no-deps trl==0.27.0 && \ pip3 install --no-cache-dir nvtx matplotlib liger_kernel cachetools && \ pip install --no-cache-dir -U git+https://github.com/ISEEKYAN/mbridge.git && \ - pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.15.0 + pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.16.0 # ============================================================================== From 07622572ac9d059f3a3046f3ac3d5ee93beebd16 Mon Sep 17 00:00:00 2001 From: Superjomn <328693+Superjomn@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:52:29 +0800 Subject: [PATCH 2/3] [rollout,trtllm] fix: update rollout for TRT-LLM 1.3.0rc10 compatibility - Fix ExecutorMemoryType import path change in 1.3.0rc10 - Remove model_weights/draft_model_weights from fallback _WEIGHTS_TAGS - Defer ServerAdapter import to avoid FlashInfer crash on CPU orchestrator - Resolve CUTLASS SM 9.0 PTX failures on L20 (SM 8.9) GPUs via sm89 target - Disable PDL, fix MoE backend case, guard FlashInfer import - Update trtllm_worker.rst docs for 
new API Co-authored-by: Claude Sonnet 4.6 --- docs/workers/trtllm_worker.rst | 13 +++----- .../trtllm_rollout/trtllm_async_server.py | 23 ++++++++++--- .../rollout/trtllm_rollout/trtllm_rollout.py | 33 +++++++++++++------ .../trtllm_rollout/trtllm_worker_extension.py | 26 +++++++++++---- 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/docs/workers/trtllm_worker.rst b/docs/workers/trtllm_worker.rst index ad6781f5e3b..0319375165a 100644 --- a/docs/workers/trtllm_worker.rst +++ b/docs/workers/trtllm_worker.rst @@ -1,7 +1,7 @@ TensorRT-LLM Backend ==================== -Last updated: 12/31/2025. +Last updated: 4/2/2026. **Authored By TensorRT-LLM Team** @@ -14,7 +14,7 @@ The TensorRT-LLM rollout engine primarily targets the colocated mode. Instead of Installation ------------ -We provide ``docker/Dockerfile.stable.trtllm`` for building a docker image with TensorRT-LLM pre-installed. The verl integration is supported from ``nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc6``, and you can choose other TensorRT-LLM versions via ``TRTLLM_BASE_IMAGE`` from the `NGC Catalog `_. +We provide `docker/Dockerfile.stable.trtllm `_ for building a docker image with TensorRT-LLM pre-installed. The verl integration is supported from ``nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc6``, and you can choose other TensorRT-LLM versions via ``TRTLLM_BASE_IMAGE`` from the `NGC Catalog `_. Alternatively, refer to the `TensorRT-LLM installation guide `_ for compatible environments if you want to build your own. @@ -51,12 +51,7 @@ We provide the following GRPO recipe scripts for you to test the performance and Using TensorRT-LLM as the Rollout Engine for DAPO ------------------------------------------------- -We provide a DAPO recipe script ``recipe/dapo/test_dapo_7b_math_trtllm.sh``. - .. 
code-block:: bash - ## For FSDP training engine - bash recipe/dapo/test_dapo_7b_math_trtllm.sh - ## For Megatron-Core training engine - TRAIN_ENGINE=megatron bash recipe/dapo/test_dapo_7b_math_trtllm.sh - + # For Megatron-Core training engine with FP8 rollout + bash examples/grpo_trainer/run_qwen3-30b_dapo_megatron_fp8_trtllm.sh diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 5b835dc2093..1a740ebdb03 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -28,7 +28,6 @@ from verl.utils.net_utils import is_valid_ipv6_address from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput -from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter from verl.workers.rollout.utils import get_max_position_embeddings, qwen2_5_vl_dedup_image_tokens, run_uvicorn logger = logging.getLogger(__file__) @@ -117,6 +116,12 @@ def get_server_address(self): async def launch_server(self): from tensorrt_llm import AsyncLLM from tensorrt_llm.llmapi import CapacitySchedulerPolicy, CudaGraphConfig, KvCacheConfig, SchedulerConfig + + try: + from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType, SleepConfig + except ImportError: + ExecutorMemoryType = None + SleepConfig = None from tensorrt_llm.serve import OpenAIServer assert self.config.pipeline_model_parallel_size == 1, "pipeline_model_parallel_size > 1 is not supported yet" @@ -164,7 +169,14 @@ async def launch_server(self): "placement_groups": self.pgs, "placement_bundle_indices": self.bundle_indices, "per_worker_gpu_share": per_worker_gpu_share, - "enable_sleep": self.config.enable_sleep_mode, + "sleep_config": SleepConfig( + restore_modes={ + ExecutorMemoryType.MODEL_WEIGHTS_MAIN: "NONE", + ExecutorMemoryType.KV_CACHE: "NONE", + } + ) + if 
self.config.enable_sleep_mode and SleepConfig is not None + else None, "allreduce_strategy": "NCCL", "sampler_type": "TRTLLMSampler", **engine_kwargs, @@ -285,6 +297,8 @@ async def resume_generation(self): raise NotImplementedError async def wake_up(self): + from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter + if self.rollout_mode == RolloutMode.HYBRID: # In hybrid mode, rollout is wake up in `update_weights` raise ValueError(f"wake_up not support rollout_mode {self.rollout_mode}") @@ -294,6 +308,8 @@ async def wake_up(self): logger.info("skip wake_up in standalone mode") async def sleep(self): + from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter + if not self.config.free_cache_engine: return @@ -348,8 +364,8 @@ def get_pgs_and_bundle_indices(self) -> tuple[list[PlacementGroup], list[list[in local_bundle_index = self.world_size * self.replica_rank while local_bundle_index >= self.resource_pool.pgs[start_pg_index].bundle_count: - start_pg_index += 1 local_bundle_index -= self.resource_pool.pgs[start_pg_index].bundle_count + start_pg_index += 1 assert ( start_pg_index < len(self.resource_pool.pgs) and local_bundle_index < self.resource_pool.pgs[start_pg_index].bundle_count @@ -386,7 +402,6 @@ def get_pgs_and_bundle_indices(self) -> tuple[list[PlacementGroup], list[list[in return pgs, bundle_indices async def launch_servers(self): - assert self.nnodes == 1, "TRTLLMReplica doesn't support multiple nodes for single replica yet." 
assert self.resource_pool.pgs is not None, "placement groups are not initialized" pgs, bundle_indices = self.get_pgs_and_bundle_indices() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 606cb4b019f..4fbf0e614f4 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -29,6 +29,12 @@ import ray import torch import torch.distributed as dist + +try: + from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType +except (ImportError, RuntimeError): + # RuntimeError: FlashInfer's check_cuda_arch() crashes on CPU-only actors + ExecutorMemoryType = None from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.multiprocessing.reductions import reduce_tensor @@ -260,16 +266,23 @@ async def update_weights(self, weights: dict[str, str]): class ServerAdapter(BaseRollout): - _WEIGHTS_TAGS = [ - "sampler", - "drafter", - "guided_decoder", - "spec_resource_manager", - "model_extra", - "executor_extra", - "model", - "draft_model", - ] + # All releasable/resumable weight tags: every ExecutorMemoryType except kv_cache + # (handled separately) and internal tags prefixed with "_". + # Fallback to hard-coded list for trtllm versions that don't export ExecutorMemoryType. 
+ _WEIGHTS_TAGS = ( + [t.value for t in ExecutorMemoryType if t is not ExecutorMemoryType.KV_CACHE and not t.value.startswith("_")] + if ExecutorMemoryType is not None + else [ + "sampler", + "drafter", + "guided_decoder", + "spec_resource_manager", + "model_extra", + "executor_extra", + "model", + "draft_model", + ] + ) @staticmethod def get_full_tags() -> list[str]: diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 4beb85f70e2..0fc5a06cd8e 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -15,12 +15,26 @@ import inspect from typing import Optional -from tensorrt_llm import serialization -from tensorrt_llm._ray_utils import control_action_decorator -from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer -from tensorrt_llm._torch.utils import get_device_uuid -from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as TrtllmWorkerExtension -from tensorrt_llm.logger import logger +# Guard the tensorrt_llm imports so that FlashInfer's check_cuda_arch() crash +# cannot break loading this module on CPU-only Ray actors. The module is normally +# loaded only on GPU workers via string path in trtllm_async_server.py, but +# the try/except below also covers transitive imports. +try: + from tensorrt_llm import serialization + from tensorrt_llm._ray_utils import control_action_decorator + from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer + from tensorrt_llm._torch.utils import get_device_uuid + from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as TrtllmWorkerExtension + from tensorrt_llm.logger import logger +except (ImportError, RuntimeError): + # On CPU actors without CUDA, these imports may fail. + # The class below won't be usable, but the module can be imported safely.
+ serialization = None + control_action_decorator = lambda f: f # noqa: E731 — identity fallback + MoeLoadBalancer = None + get_device_uuid = None + TrtllmWorkerExtension = object + logger = None class WorkerExtension(TrtllmWorkerExtension): From 687adb08c3b6477a0801378c35af6ccb53f8204c Mon Sep 17 00:00:00 2001 From: Superjomn <328693+Superjomn@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:52:38 +0800 Subject: [PATCH 3/3] [ci] fix: update trtllm e2e CI workflow for 1.3.0rc10 - Bump CI image tag to 1.3.0rc10 - Restore TORCH_CUDA_ARCH_LIST as runtime fallback for Ray workers Co-authored-by: Claude Sonnet 4.6 --- .../workflows/e2e_ppo_grpo_trainer_trtllm.yml | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml b/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml index 3c2e62f99cd..079ec2dd770 100644 --- a/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml +++ b/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml @@ -93,7 +93,7 @@ permissions: contents: read env: - IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:trtllm1.3.0rc4" + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:trtllm1.3.0rc10" DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" jobs: @@ -121,10 +121,15 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + TORCH_CUDA_ARCH_LIST: "7.5;8.0;8.9;9.0;10.0;12.0+PTX" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 + - name: Reinstall FlashInfer for runtime GPU arch + run: | + unset TORCH_CUDA_ARCH_LIST + pip3 install --force-reinstall --no-deps flashinfer-python - name: Install the current repository run: | pip3 install pytest-asyncio @@ -132,6 +137,8 @@ jobs: pip3 install --no-deps -e . 
- name: Run TRTLLM unit tests run: | + unset TORCH_CUDA_ARCH_LIST # Let TRT-LLM CUTLASS DSL auto-detect runtime GPU arch + export TRTLLM_ENABLE_PDL=0 # Disable Programmatic Dependent Launch (uses SM 9.0 griddepcontrol) export TRTLLM_TEST_MODEL_PATH_ROOT="${HOME}/models" ray stop --force pytest -v -s --durations=20 \ @@ -148,10 +155,15 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + TORCH_CUDA_ARCH_LIST: "7.5;8.0;8.9;9.0;10.0;12.0+PTX" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 + - name: Reinstall FlashInfer for runtime GPU arch + run: | + unset TORCH_CUDA_ARCH_LIST + pip3 install --force-reinstall --no-deps flashinfer-python - name: Install the current repository run: | pip3 install -r requirements-test.txt @@ -161,6 +173,8 @@ jobs: python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k --local_save_dir ${PWD}/data/gsm8k - name: Running GSM8K E2E training tests with FSDP on 8 L20 GPUs (Qwen) run: | + unset TORCH_CUDA_ARCH_LIST # Let TRT-LLM CUTLASS DSL auto-detect runtime GPU arch + export TRTLLM_ENABLE_PDL=0 # Disable Programmatic Dependent Launch (uses SM 9.0 griddepcontrol) ray stop --force DATADIR=${HOME}/data \ bash examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh 2 \ @@ -193,10 +207,15 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + TORCH_CUDA_ARCH_LIST: "7.5;8.0;8.9;9.0;10.0;12.0+PTX" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 + - name: Reinstall FlashInfer for runtime GPU arch + run: | + unset TORCH_CUDA_ARCH_LIST + pip3 install --force-reinstall --no-deps flashinfer-python - name: Install the current repository run: | pip3 install -r requirements-test.txt @@ -206,6 +225,8 @@ jobs: python3 
examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k --local_save_dir ${PWD}/data/gsm8k - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) run: | + unset TORCH_CUDA_ARCH_LIST # Let TRT-LLM CUTLASS DSL auto-detect runtime GPU arch + export TRTLLM_ENABLE_PDL=0 # Disable Programmatic Dependent Launch (uses SM 9.0 griddepcontrol) ray stop --force DATADIR=${HOME}/data \ ACTOR_TP=2 \ @@ -235,10 +256,15 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + TORCH_CUDA_ARCH_LIST: "7.5;8.0;8.9;9.0;10.0;12.0+PTX" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 + - name: Reinstall FlashInfer for runtime GPU arch + run: | + unset TORCH_CUDA_ARCH_LIST + pip3 install --force-reinstall --no-deps flashinfer-python - name: Install the current repository run: | pip3 install -r requirements-test.txt @@ -268,9 +294,11 @@ jobs: python3 examples/data_preprocess/aime2024_multiturn_w_tool.py --local_save_dir ${PWD}/data/aime-2024 - name: Running DAPO E2E with FP8 TRT-LLM rollout (Qwen3-0.6B) run: | + unset TORCH_CUDA_ARCH_LIST # Let TRT-LLM CUTLASS DSL auto-detect runtime GPU arch + export TRTLLM_ENABLE_PDL=0 # Disable Programmatic Dependent Launch (uses SM 9.0 griddepcontrol) ray stop --force export INFER_TP=2 ACTOR_TP=2 ACTOR_PP=2 ACTOR_VPP=2 ACTOR_EP=1 ACTOR_CP=2 REF_TP=2 REF_PP=2 REF_VPP=2 REF_EP=1 REF_CP=2 GEN_MOE_TP=null GEN_MOE_EP=null - export NNODES=1 GPUS_PER_NODE=8 TRTLLM_MOE_BACKEND=CUTLASS + export NNODES=1 GPUS_PER_NODE=8 TRTLLM_MOE_BACKEND=TRITON export DATA_DIR=${PWD} DAPO_MATH_TRAIN=${PWD}/data/dapo-math-17k/train.parquet AIME_VAL=${PWD}/data/aime-2024/train.parquet MODEL_PATH=${HOME}/models/Qwen/Qwen3-0.6B bash examples/grpo_trainer/run_qwen3-30b_dapo_megatron_fp8_trtllm.sh \ reward_model.reward_kwargs.overlong_buffer_cfg.len=258 \