From 738fc757f407c62b0c8130823a20e882b6324b27 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:16:07 -0700 Subject: [PATCH 01/19] nemo gym integration Signed-off-by: cmunley1 --- .gitignore | 4 +- .gitmodules | 3 + 3rdparty/nemo_gym | 1 + config.env.example | 9 + submit_math.sh | 169 ++++++ submit_workplace.sh | 172 ++++++ verl/experimental/nemo_gym_agent_loop.py | 572 ++++++++++++++++++ verl/experimental/nemo_gym_dataset.py | 84 +++ .../hermes_tool_parser_patched.py | 57 ++ verl/workers/config/rollout.py | 4 +- .../rollout/vllm_rollout/vllm_async_server.py | 13 + 11 files changed, 1086 insertions(+), 2 deletions(-) create mode 160000 3rdparty/nemo_gym create mode 100644 config.env.example create mode 100755 submit_math.sh create mode 100755 submit_workplace.sh create mode 100644 verl/experimental/nemo_gym_agent_loop.py create mode 100644 verl/experimental/nemo_gym_dataset.py create mode 100644 verl/experimental/tool_parsers/hermes_tool_parser_patched.py diff --git a/.gitignore b/.gitignore index e6c0f5a08e3..0aa648b904e 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,6 @@ ENV/ logs log outputs -.history \ No newline at end of file +.history/ +logs/ +config.env diff --git a/.gitmodules b/.gitmodules index d5dd7a6aa57..de9f6e53919 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "recipe"] path = recipe url = https://github.com/verl-project/verl-recipe.git +[submodule "3rdparty/nemo_gym"] + path = 3rdparty/nemo_gym + url = https://github.com/NVIDIA-NeMo/Gym diff --git a/3rdparty/nemo_gym b/3rdparty/nemo_gym new file mode 160000 index 00000000000..f1399809b86 --- /dev/null +++ b/3rdparty/nemo_gym @@ -0,0 +1 @@ +Subproject commit f1399809b86c8d3f669bf3cf5c24efabec62dfad diff --git a/config.env.example b/config.env.example new file mode 100644 index 00000000000..0efe623aefc --- /dev/null +++ b/config.env.example @@ -0,0 +1,9 @@ +# Copy this to config.env and fill in your values. config.env is gitignored. + +VERL_ROOT=/path/to/verl +NEMO_GYM_ROOT=/path/to/gym-ref +HF_HOME=/path/to/hf_home +RESULTS_ROOT=/path/to/results +DATA_ROOT=/path/to/data +WANDB_USERNAME=your_wandb_username +WANDB_API_KEY=your_key_here diff --git a/submit_math.sh b/submit_math.sh new file mode 100755 index 00000000000..cff7f20b915 --- /dev/null +++ b/submit_math.sh @@ -0,0 +1,169 @@ +#!/bin/bash +#SBATCH --job-name=verl-nemogym-dapo-7b-math +#SBATCH --nodes=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=4:00:00 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --output=logs/slurm-%j.out +#SBATCH --error=logs/slurm-%j.err + +set -euo pipefail + +GPUS_PER_NODE=8 + +source "$(dirname "$0")/config.env" + +MODEL_PATH="${DATA_ROOT}/models/Qwen2.5-Math-7B" +TRAIN_FILE="${NEMO_GYM_ROOT}/resources_servers/math_with_judge/data/dapo17k_bytedtsinghua_train_nrl.jsonl" +TEST_FILE="${NEMO_GYM_ROOT}/resources_servers/math_with_judge/data/aime24_bytedtsinghua_validation_nrl.jsonl" +CKPTS_DIR="${RESULTS_ROOT}/DAPO-Qwen2.5-7b-MATH-megatron" + +CONTAINER="verlai/verl:vllm017.latest" +MOUNTS="/lustre:/lustre" + +mkdir -p "${CKPTS_DIR}" + +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}') + +RAY_PORT=6379 +ip_head="${head_node_ip}:${RAY_PORT}" +echo "Head node: ${head_node} (${head_node_ip})" + +SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}" + +echo "Starting Ray head on ${head_node}..." +srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start --head \ + --node-ip-address="${head_node_ip}" \ + --port=${RAY_PORT} \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & +sleep 10 + +worker_num=$((SLURM_JOB_NUM_NODES - 1)) +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting Ray worker ${i} on ${node_i}..." + srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start \ + --address="${ip_head}" \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & + sleep 5 +done + +CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head" +echo "Waiting for ray-head container at ${CONTAINER_DIR}..." +elapsed=0 +while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do + sleep 5 + elapsed=$((elapsed + 5)) +done +if [[ ! -d "${CONTAINER_DIR}" ]]; then + echo "ERROR: ray-head container never appeared after 300s" + exit 1 +fi +echo "Container ready. Waiting 90s for all Ray workers to connect..." +sleep 90 + +echo "Installing nemo-gym..." +srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-name=ray-head \ + bash -c "touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" + +echo "Launching training on ${head_node}..." +PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-workdir=${VERL_ROOT} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES \ + WANDB_API_KEY="${WANDB_API_KEY}" \ + HF_HOME="${HF_HOME}" \ + HF_HUB_CACHE="${HF_HOME}/hub" \ + RAY_ADDRESS="auto" \ + VLLM_USE_V1=1 \ + TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + PYTHONPATH="${NEMO_GYM_ROOT}" \ + RAY_grpc_keepalive_time_ms=60000 \ + RAY_grpc_keepalive_timeout_ms=600000 \ + RAY_grpc_client_keepalive_time_ms=60000 \ + RAY_grpc_client_keepalive_timeout_ms=600000 \ + python3 -m verl.trainer.main_ppo \ + --config-path="${VERL_ROOT}/recipe/dapo/config" \ + --config-name=dapo_megatron_trainer.yaml \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym_dataset.py" \ + +data.custom_cls.name=NemoGymJSONLDataset \ + data.truncation=left \ + data.max_prompt_length=2048 \ + data.max_response_length=8192 \ + data.train_batch_size=512 \ + actor_rollout_ref.rollout.n=16 \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.28 \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.param_offload=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=True \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=4096 \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=1.0 \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=8192 \ + 'trainer.logger=["console","wandb"]' \ + trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ + trainer.experiment_name=dapo-7b-nemogym \ + trainer.n_gpus_per_node=${GPUS_PER_NODE} \ + trainer.nnodes=${SLURM_JOB_NUM_NODES} \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym_agent_loop.NemoGymAgentLoopManager' \ + "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/math_with_judge/configs/math_with_judge.yaml]" \ + +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ + +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ + 2>&1 diff --git a/submit_workplace.sh b/submit_workplace.sh new file mode 100755 index 00000000000..ec85a6bab74 --- /dev/null +++ b/submit_workplace.sh @@ -0,0 +1,172 @@ +#!/bin/bash +#SBATCH --job-name=verl-nemogym-dapo-4b-workplace +#SBATCH --nodes=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=4:00:00 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --output=logs/slurm-%j.out +#SBATCH --error=logs/slurm-%j.err + +set -euo pipefail + +GPUS_PER_NODE=8 + +source "$(dirname "$0")/config.env" + +MODEL_PATH="${HF_HOME}/hub/models--Qwen--Qwen3-4B-Instruct-2507/snapshots/cdbee75f17c01a7cc42f958dc650907174af0554" +TRAIN_FILE="${DATA_ROOT}/workplace_assistant/train.jsonl" +TEST_FILE="${DATA_ROOT}/workplace_assistant/validation.jsonl" +CKPTS_DIR="${RESULTS_ROOT}/dapo-qwen3-4b-workplace" + +CONTAINER="verlai/verl:vllm017.latest" +MOUNTS="/lustre:/lustre" + +mkdir -p "${CKPTS_DIR}" + +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}') + +RAY_PORT=6379 +ip_head="${head_node_ip}:${RAY_PORT}" +echo "Head node: ${head_node} (${head_node_ip})" + +SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}" + +echo "Starting Ray head on ${head_node}..." +srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start --head \ + --node-ip-address="${head_node_ip}" \ + --port=${RAY_PORT} \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & +sleep 10 + +worker_num=$((SLURM_JOB_NUM_NODES - 1)) +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting Ray worker ${i} on ${node_i}..." + srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start \ + --address="${ip_head}" \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & + sleep 5 +done + +CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head" +echo "Waiting for ray-head container at ${CONTAINER_DIR}..." +elapsed=0 +while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do + sleep 5 + elapsed=$((elapsed + 5)) +done +if [[ ! -d "${CONTAINER_DIR}" ]]; then + echo "ERROR: ray-head container never appeared after 300s" + exit 1 +fi +echo "Container ready. Waiting 90s for all Ray workers to connect..." +sleep 90 + +echo "Installing nemo-gym..." +srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-name=ray-head \ + bash -c "PYTHONPATH= touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" + +echo "Launching training on ${head_node}..." +PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-workdir=${VERL_ROOT} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES \ + WANDB_API_KEY="${WANDB_API_KEY}" \ + HF_HOME="${HF_HOME}" \ + HF_HUB_CACHE="${HF_HOME}/hub" \ + RAY_ADDRESS="auto" \ + VLLM_USE_V1=1 \ + TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + PYTHONPATH="${NEMO_GYM_ROOT}" \ + RAY_grpc_keepalive_time_ms=60000 \ + RAY_grpc_keepalive_timeout_ms=600000 \ + RAY_grpc_client_keepalive_time_ms=60000 \ + RAY_grpc_client_keepalive_timeout_ms=600000 \ + python3 -m verl.trainer.main_ppo \ + --config-path="${VERL_ROOT}/recipe/dapo/config" \ + --config-name=dapo_megatron_trainer.yaml \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym_dataset.py" \ + +data.custom_cls.name=NemoGymJSONLDataset \ + data.truncation=left \ + data.train_batch_size=512 \ + actor_rollout_ref.rollout.n=16 \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.28 \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes_patched' \ + "+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-parser-plugin=${VERL_ROOT}/verl/experimental/tool_parsers/hermes_tool_parser_patched.py" \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.param_offload=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=True \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=4096 \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=1.0 \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=8192 \ + 'trainer.logger=["console","wandb"]' \ + trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ + trainer.experiment_name=dapo-qwen3-4b-workplace \ + trainer.n_gpus_per_node=${GPUS_PER_NODE} \ + trainer.nnodes=${SLURM_JOB_NUM_NODES} \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym_agent_loop.NemoGymAgentLoopManager' \ + "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/workplace_assistant/configs/workplace_assistant.yaml]" \ + +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ + +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ + 2>&1 diff --git a/verl/experimental/nemo_gym_agent_loop.py b/verl/experimental/nemo_gym_agent_loop.py new file mode 100644 index 00000000000..40882951f69 --- /dev/null +++ b/verl/experimental/nemo_gym_agent_loop.py @@ -0,0 +1,572 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import asyncio +import os +import socket +import sys +import threading +from typing import Optional + +import ray +import torch + +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopManager, + AgentLoopMetrics, + _InternalAgentLoopOutput, +) +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopWorker as _AgentLoopWorker, +) +from verl.protocol import DataProto +from verl.utils.model import compute_position_id_with_mask +from verl.utils.ray_utils import auto_await + +_postprocess = _AgentLoopWorker._postprocess + + +class NemoGymAgentLoopManager(AgentLoopManager): + @classmethod + @auto_await + async def create( + cls, + config, + worker_group=None, + rollout_resource_pool=None, + reward_loop_worker_handles=None, + teacher_model_manager=None, + ) -> NemoGymAgentLoopManager: + instance = cls( + config, + worker_group, + rollout_resource_pool, + teacher_model_manager, + reward_loop_worker_handles, + ) + await instance._initialize_llm_servers() + await instance._init_global_load_balancer() + await instance._init_nemo_gym() + return instance + + async def _init_nemo_gym(self) -> None: + nemo_gym_cfg = self.rollout_config.agent.get("nemo_gym", {}) + + # PYTHONPATH env var isn't reliable inside Ray actor processes (already running + # when env var is set); sys.path injection works instead. + # TODO: seems hacky. proper method maybe is pip install? but thats a bit broken + nemo_gym_root = nemo_gym_cfg.get("nemo_gym_root", None) + if nemo_gym_root and str(nemo_gym_root) not in sys.path: + sys.path.insert(0, str(nemo_gym_root)) + + try: + from nemo_gym.cli import GlobalConfigDictParserConfig, RunHelper + from nemo_gym.rollout_collection import RolloutCollectionHelper + from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, BaseServerConfig + from omegaconf import DictConfig + except ModuleNotFoundError as e: + if "nemo_gym" in str(e): + raise ImportError( + "nemo-gym not found. Set nemo_gym_root in config: " + "+actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root=/path/to/gym-ref" + ) from e + raise + + initial_global_cfg = dict(nemo_gym_cfg.get("initial_global_config_dict", {})) + + uses_reasoning_parser = bool(nemo_gym_cfg.get("uses_reasoning_parser", False)) + vllm_model_cfg = ( + initial_global_cfg.setdefault("policy_model", {}) + .setdefault("responses_api_models", {}) + .setdefault("vllm_model", {}) + ) + vllm_model_cfg["uses_reasoning_parser"] = uses_reasoning_parser + + # Disable thinking if no reasoning parser + if not uses_reasoning_parser: + vllm_model_cfg.setdefault("extra_body", {}).setdefault("chat_template_kwargs", {})["enable_thinking"] = ( + False + ) + + base_urls = [ + (addr if addr.startswith("http") else f"http://{addr}").rstrip("/") + "/v1" + for addr in self.server_addresses + ] + initial_global_cfg["policy_model_name"] = self.model_config.get("path", "") + initial_global_cfg["policy_api_key"] = "dummy_key" + initial_global_cfg["policy_base_url"] = base_urls + initial_global_cfg.setdefault("global_aiohttp_connector_limit_per_host", 16_384) + initial_global_cfg.setdefault("global_aiohttp_connector_limit", 65_536) + + ray_context = ray.get_runtime_context() + initial_global_cfg["ray_head_node_address"] = ray_context.gcs_address + + if nemo_gym_root: + initial_global_cfg.setdefault("uv_venv_dir", str(nemo_gym_root)) + initial_global_cfg.setdefault("skip_venv_if_present", True) + + node_ip = socket.gethostbyname(socket.gethostname()) + with socket.socket() as s: + s.bind(("", 0)) + head_port = s.getsockname()[1] + initial_global_cfg[HEAD_SERVER_KEY_NAME] = {"host": "0.0.0.0", "port": head_port} + self._head_server_config = BaseServerConfig(host=node_ip, port=head_port) + + if nemo_gym_root: + existing = os.environ.get("PYTHONPATH", "") + os.environ["PYTHONPATH"] = f"{nemo_gym_root}:{existing}" if existing else str(nemo_gym_root) + + # Auto-detect agent ref. maybe dangerous for multi-environment + # TODO test multienv + self._default_agent_ref = None + config_paths = initial_global_cfg.get("config_paths", []) + for config_path in config_paths if isinstance(config_paths, list) else []: + try: + from omegaconf import OmegaConf + + yaml_cfg = OmegaConf.load(config_path) + for key in yaml_cfg: + entry = yaml_cfg[key] + if isinstance(entry, dict) and "responses_api_agents" in entry: + self._default_agent_ref = {"type": "responses_api_agents", "name": key} + print(f"[NemoGymAgentLoopManager] Detected agent: {key}") + break + except Exception: + pass + if self._default_agent_ref: + break + + # start nemo gym servers + self._rh = RunHelper() + self._rh.start( + global_config_dict_parser_config=GlobalConfigDictParserConfig( + dotenv_path=None, + initial_global_config_dict=DictConfig(initial_global_cfg), + skip_load_from_cli=True, + ) + ) + + self._rch = RolloutCollectionHelper() + + # AgentLoopManager stores model_config as raw DictConfig; convert here + # since AgentLoopWorker which normally does this isn't used. + from verl.utils.config import omega_conf_to_dataclass + + self._tokenizer = omega_conf_to_dataclass(self.model_config).tokenizer + + # asyncio.run() was recreating loop on each step and erroring.. so make a single persistent loop + self._rollout_loop = asyncio.new_event_loop() + self._rollout_thread = threading.Thread( + target=self._rollout_loop.run_forever, + daemon=True, + name="nemo-gym-rollout-loop", + ) + self._rollout_thread.start() + + print(f"NemoGymAgentLoopManager ready: {len(base_urls)} vLLM endpoints: {base_urls}") + + def generate_sequences(self, prompts: DataProto) -> DataProto: + future = asyncio.run_coroutine_threadsafe(self._async_generate_sequences(prompts), self._rollout_loop) + return future.result() + + async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: + validate = prompts.meta_info.get("validate", False) + global_steps = prompts.meta_info.get("global_steps", -1) + + nemo_gym_examples = _build_nemo_gym_examples( + prompts, + self.rollout_config, + validate=validate, + default_agent_ref=self._default_agent_ref, + ) + + # run rollout collection + nemo_gym_result_iterator = self._rch.run_examples( + examples=nemo_gym_examples, + head_server_config=self._head_server_config, + ) + + # collect results + rowidxs, raw_results = [], [] + for task in nemo_gym_result_iterator: + nemo_gym_row, nemo_gym_result = await task + try: + result = _postprocess_nemo_gym_result(nemo_gym_result, self._tokenizer) + except ValueError: + # Context length exceeded. return a dummy result (1 token, 0 reward, 0 logprob) + # TODO: should we fail here instead or what? i dont like dummy result. what does nemo rl do? + result = _empty_result(nemo_gym_row, self._tokenizer) + rowidxs.append(nemo_gym_row["_rowidx"]) + raw_results.append(result) + + results = [None] * len(nemo_gym_examples) + for rowidx, result in zip(rowidxs, raw_results, strict=False): + results[rowidx] = result + + # pad to batch max instead of setting data.max_prompt_length + # TODO: review what we are padding here, is this dangerous or wasteful? + prompt_lens = [sum(len(m["token_ids"]) for m in r["input_message_log"]) for r in results] + response_lens = [ + sum(len(m["token_ids"]) for m in r["message_log"][len(r["input_message_log"]) :]) for r in results + ] + prompt_length = max(prompt_lens) if prompt_lens else self.rollout_config.prompt_length + response_length = max(response_lens) if response_lens else self.rollout_config.response_length + + internal_outputs = [ + _nemo_gym_result_to_verl( + result=result, + tokenizer=self._tokenizer, + prompt_length=prompt_length, + response_length=response_length, + ) + for result in results + ] + + output = _postprocess( + self, + internal_outputs, + input_non_tensor_batch=prompts.non_tensor_batch, + validate=validate, + ) + output.meta_info["global_steps"] = global_steps + # ray_trainer.py merges this into its own timing_raw at line ~1362; empty is fine + output.meta_info["timing"] = {} + output.meta_info["rollout_metrics"] = _compute_rollout_metrics( + results, getattr(self.rollout_config, "max_model_len", None) + ) + return output + + +def _build_nemo_gym_examples( + prompts: DataProto, + rollout_config, + validate: bool = False, + default_agent_ref=None, +) -> list[dict]: + cfg = rollout_config + temperature = cfg.val_kwargs.temperature if validate else cfg.temperature + top_p = cfg.val_kwargs.top_p if validate else cfg.top_p + + non_tensor = prompts.non_tensor_batch + examples = [] + for i in range(len(prompts)): + messages = list(non_tensor["raw_prompt"][i]) + + if "agent_ref" in non_tensor: + agent_ref = non_tensor["agent_ref"][i] + elif default_agent_ref is not None: + agent_ref = default_agent_ref + else: + agent_ref = {"type": "responses_api_agents", "name": "math_with_judge_simple_agent"} + + # Build responses_create_params + rcp = {} + if "extra_env_info" in non_tensor and "_rcp_extra" in (non_tensor["extra_env_info"][i] or {}): + rcp.update(non_tensor["extra_env_info"][i]["_rcp_extra"]) + rcp.update({"input": messages, "temperature": temperature, "top_p": top_p}) + + row = { + "responses_create_params": rcp, + "agent_ref": agent_ref, + "_rowidx": i, + } + + if "extra_env_info" in non_tensor: + env_info = non_tensor["extra_env_info"][i] + if isinstance(env_info, dict): + for k, v in env_info.items(): + if k not in row: + row[k] = v + + examples.append(row) + return examples + + +def _replace_prefix_tokens( + tokenizer, + model_prefix_token_ids: list[int], + template_prefix_token_ids: list[int], + template_token_ids: list[int], +) -> list[int]: + """Fix chat-template re-tokenization differences across multi-turn calls. + + Find the first eos token that appears at or after the first + divergence point between model_prefix and template. That eos marks the + end of the re-tokenized assistant turn. Everything from that eos onward + in template_token_ids (eos + tool-result + gen-prompt) becomes the new + observation suffix, and the model's original prefix tokens are kept intact. + + TODO: review why original was not sufficient (re-tokenized turn can be shorter or longer) + TODO: verl avoids this entirely by appending token ids each turn without re-tokenizing + """ + if not model_prefix_token_ids: + return template_token_ids + raw_eos = tokenizer.eos_token_id + if raw_eos is None: + return template_token_ids + # handle case where eos_token_id is list (ideally there is just 1 lol) + eos_ids: set[int] = set(raw_eos) if isinstance(raw_eos, list) else {raw_eos} + + model_cut = len(model_prefix_token_ids) + if model_prefix_token_ids[-1] in eos_ids: + model_cut -= 1 + + # Find first position where the two sequences differ + first_diff = next( + (i for i, (a, b) in enumerate(zip(model_prefix_token_ids, template_token_ids, strict=False)) if a != b), + min(len(model_prefix_token_ids), len(template_token_ids)), + ) + + # look ahead from first_diff: find the eos that ends the re-tokenized assistant turn. + # works even when the re-tokenized turn is shorter than the original. + cut = -1 + for pos in range(first_diff, len(template_token_ids)): + if template_token_ids[pos] in eos_ids: + cut = pos + break + + # look back up to first_diff (for uncommon longer retokenized case) + if cut < 0: + for pos in reversed(range(first_diff)): + if template_token_ids[pos] in eos_ids: + cut = pos + break + + if cut < 0: + return template_token_ids + return model_prefix_token_ids[:model_cut] + template_token_ids[cut:] + + +def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: + message_log = [] + seen_token_ids: list[int] = [] + + for item in nemo_gym_result["response"]["output"]: + if "generation_token_ids" not in item: + continue + + prompt_ids = item["prompt_token_ids"] + + # If token IDs are non-contiguous apply _replace_prefix_tokens fix to restore monotonic token IDs. + # TODO use verl way + if seen_token_ids and seen_token_ids != prompt_ids[: len(seen_token_ids)]: + prompt_ids = _replace_prefix_tokens( + tokenizer, + model_prefix_token_ids=seen_token_ids, + template_prefix_token_ids=seen_token_ids, + template_token_ids=prompt_ids, + ) + + if seen_token_ids != prompt_ids[: len(seen_token_ids)]: + diverge_at = next( + (i for i, (a, b) in enumerate(zip(seen_token_ids, prompt_ids, strict=False)) if a != b), + len(seen_token_ids), + ) + raise AssertionError( + f"Non-contiguous token IDs after replace_prefix fix. " + f"seen_len={len(seen_token_ids)} prompt_len={len(prompt_ids)} " + f"first_diverge={diverge_at} " + f"seen[{diverge_at - 2}:{diverge_at + 3}]={seen_token_ids[max(0, diverge_at - 2) : diverge_at + 3]} " + f"prompt[{diverge_at - 2}:{diverge_at + 3}]={prompt_ids[max(0, diverge_at - 2) : diverge_at + 3]}" + ) + + message_log.append( + { + "role": "user", + "content": "", + "token_ids": torch.tensor(prompt_ids[len(seen_token_ids) :]), + } + ) + message_log.append( + { + "role": "assistant", + "content": "", + "token_ids": torch.tensor(item["generation_token_ids"]), + "generation_logprobs": torch.tensor(item["generation_log_probs"]), + } + ) + + seen_token_ids.extend(message_log[-2]["token_ids"].tolist()) + seen_token_ids.extend(message_log[-1]["token_ids"].tolist()) + + item.pop("prompt_token_ids", None) + item["prompt_str"] = tokenizer.decode(prompt_ids) + item["generation_str"] = tokenizer.decode(item.pop("generation_token_ids")) + item.pop("generation_log_probs") + + if not message_log: + raise ValueError( + "nemo-gym returned a result with no generation data. " + "The prompt (ie full context in multi step/turn) probably exceeds vLLM's max_model_len." + ) + + return { + "message_log": message_log, + "input_message_log": message_log[:1], + "full_result": nemo_gym_result, + } + + +def _empty_result(nemo_gym_row: dict, tokenizer) -> dict: + """empty result for overlong samples + TODO: should we truncate or something else? what is best practice here?""" + + messages = nemo_gym_row.get("responses_create_params", {}).get("input", []) + raw_prompt = [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages] + prompt_ids = tokenizer.apply_chat_template(raw_prompt, tokenize=True, add_generation_prompt=False)[-1:] + dummy_tok = torch.tensor(prompt_ids, dtype=torch.long) + return { + "message_log": [ + {"role": "user", "token_ids": dummy_tok}, + {"role": "assistant", "token_ids": dummy_tok, "generation_logprobs": torch.zeros(len(dummy_tok))}, + ], + "input_message_log": [{"role": "user", "token_ids": dummy_tok}], + "full_result": {"reward": 0.0}, + } + + +def _nemo_gym_result_to_verl( + result: dict, + tokenizer, + prompt_length: int, + response_length: int, +) -> _InternalAgentLoopOutput: + """Pack message_log into padded tensors for verl and mask non assistant messages""" + message_log = result["message_log"] + input_message_log = result["input_message_log"] + + prompt_ids_raw: list[int] = [] + for msg in input_message_log: + tids = msg["token_ids"] + prompt_ids_raw.extend(tids.tolist() if isinstance(tids, torch.Tensor) else tids) + + n_prompt_msgs = len(input_message_log) + response_ids_raw: list[int] = [] + response_mask_raw: list[int] = [] + response_logprobs_raw: list[float] = [] + + for msg in message_log[n_prompt_msgs:]: + tids = msg["token_ids"] + toks = tids.tolist() if isinstance(tids, torch.Tensor) else list(tids) + if msg["role"] == "assistant": + response_ids_raw.extend(toks) + response_mask_raw.extend([1] * len(toks)) + lp = msg.get("generation_logprobs") + response_logprobs_raw.extend( + lp.tolist() if isinstance(lp, torch.Tensor) else list(lp) if lp is not None else [0.0] * len(toks) + ) + else: + response_ids_raw.extend(toks) + response_mask_raw.extend([0] * len(toks)) + response_logprobs_raw.extend([0.0] * len(toks)) + + response_ids_raw = response_ids_raw[:response_length] + response_mask_raw = response_mask_raw[:response_length] + response_logprobs_raw = response_logprobs_raw[:response_length] + + tokenizer.padding_side = "left" + prompt_out = tokenizer.pad( + {"input_ids": prompt_ids_raw}, + padding="max_length", + max_length=prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + prompt_ids = prompt_out["input_ids"] + if prompt_ids.dim() == 1: + prompt_ids = prompt_ids.unsqueeze(0) + prompt_out["attention_mask"] = prompt_out["attention_mask"].unsqueeze(0) + + tokenizer.padding_side = "right" + response_out = tokenizer.pad( + {"input_ids": response_ids_raw}, + padding="max_length", + max_length=response_length, + return_tensors="pt", + return_attention_mask=True, + ) + response_ids = response_out["input_ids"] + response_attn = response_out["attention_mask"] + if response_ids.dim() == 1: + response_ids = response_ids.unsqueeze(0) + response_attn = response_attn.unsqueeze(0) + + pad = response_length - len(response_mask_raw) + response_mask = torch.tensor(response_mask_raw + [0] * pad, dtype=torch.long).unsqueeze(0) * response_attn + + pad = response_length - len(response_logprobs_raw) + response_logprobs = torch.tensor(response_logprobs_raw + [0.0] * pad, dtype=torch.float32).unsqueeze(0) + + attention_mask = torch.cat([prompt_out["attention_mask"], response_attn], dim=1) + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + + position_ids = compute_position_id_with_mask(attention_mask) + + reward_score: Optional[float] = result.get("full_result", {}).get("reward", None) + num_turns = sum(1 for m in message_log if m["role"] == "assistant") + + return _InternalAgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids, + input_ids=input_ids, + position_ids=position_ids, + response_mask=response_mask, + attention_mask=attention_mask, + response_logprobs=response_logprobs, + routed_experts=None, + multi_modal_inputs={}, + teacher_logprobs=None, + teacher_ids=None, + reward_score=reward_score, + num_turns=num_turns, + metrics=AgentLoopMetrics(), + extra_fields={}, + ) + + +def _compute_rollout_metrics(results: list[dict], max_model_len: Optional[int] = None) -> dict: + batch_size = len(results) + if batch_size == 0: + return {} + + def _mean(vals): + return sum(vals) / len(vals) if vals else 0.0 + + stats = [] + for r in results: + ml = r["message_log"] + total = sum(len(m["token_ids"]) for m in ml) + asst = sum(len(m["token_ids"]) for m in ml if m["role"] == "assistant") + turns = sum(1 for m in ml if m["role"] == "user") + hit_max = (max_model_len is not None) and (total >= max_model_len) + stats.append( + { + "reward": float(r.get("full_result", {}).get("reward", 0.0)), + "asst": asst, + "total": total, + "turns": turns, + "hit_max": hit_max, + } + ) + + return { + "turns_per_sample/mean": _mean([s["turns"] for s in stats]), + "total_tokens_per_sample/mean": _mean([s["total"] for s in stats]), + "gen_tokens_per_sample/mean": _mean([s["asst"] for s in stats]), + "mean_gen_tokens_per_sample": _mean([s["asst"] for s in stats]), + "total_reward/mean": _mean([s["reward"] for s in stats]), + "natural_termination_rate": sum(not s["hit_max"] for s in stats) / batch_size, + "truncation_rate": sum(s["hit_max"] for s in stats) / batch_size, + } diff --git a/verl/experimental/nemo_gym_dataset.py b/verl/experimental/nemo_gym_dataset.py new file mode 100644 index 00000000000..f3bdc4ed228 --- /dev/null +++ b/verl/experimental/nemo_gym_dataset.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from pathlib import Path + +import torch +from torch.utils.data import Dataset + + +class NemoGymJSONLDataset(Dataset): + def __init__( + self, + data_files: str | list[str], + tokenizer, + processor=None, + config=None, + **kwargs, + ): + if isinstance(data_files, str): + data_files = [data_files] + + self.tokenizer = tokenizer + self._rows: list[dict] = [] + + for path in data_files: + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"NemoGymJSONLDataset: file not found: {path}") + with open(path) as f: + for line in f: + line = line.strip() + if line: + self._rows.append(json.loads(line)) + + def __len__(self) -> int: + return len(self._rows) + + def __getitem__(self, idx: int) -> dict: + row = self._rows[idx] + + rcp = row.get("responses_create_params", {}) + messages = rcp.get("input", []) + raw_prompt = [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages] + + agent_ref = row.get("agent_ref", None) + + skip_keys = {"responses_create_params", "agent_ref"} + extra_env_info = {k: v for k, v in row.items() if k not in skip_keys} + + # preserve per-task fields like `tools` to nemo-gym request + # input/temperature/top_p/top_k are overridden per training step + # parallel_tool_calls is dropped because vLLM throws 500 error + _rcp_skip = {"input", "temperature", "top_p", "top_k", "parallel_tool_calls"} + rcp_extra = {k: v for k, v in rcp.items() if k not in _rcp_skip} + if rcp_extra: + extra_env_info["_rcp_extra"] = rcp_extra + + out = { + "raw_prompt": raw_prompt, + # unused placeholder. ray_trainer.py calls len(batch.batch) which requires batch to be non-empty + "__nemo_gym_batch_size__": torch.zeros(1, dtype=torch.long), + } + if agent_ref is not None: + out["agent_ref"] = agent_ref + out["extra_env_info"] = extra_env_info + + return out + + @property + def collate_fn(self): + from verl.utils.dataset.rl_dataset import collate_fn + + return collate_fn diff --git a/verl/experimental/tool_parsers/hermes_tool_parser_patched.py b/verl/experimental/tool_parsers/hermes_tool_parser_patched.py new file mode 100644 index 00000000000..f697863e154 --- /dev/null +++ b/verl/experimental/tool_parsers/hermes_tool_parser_patched.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +try: + from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager +except ImportError: + from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager + +try: + from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser +except ImportError: + from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser + + +@ToolParserManager.register_module("hermes_patched") +class HermesPatchedToolParser(Hermes2ProToolParser): + def __init__(self, tokenizer): + ToolParser.__init__(self, tokenizer) + + try: + from vllm.transformers_utils.tokenizer import MistralTokenizer + + if isinstance(self.model_tokenizer, MistralTokenizer): + self.model_tokenizer = self.model_tokenizer.tokenizer + except ImportError: + pass + + self.current_tool_name_sent = False + self.prev_tool_call_arr = [] + self.current_tool_id = -1 + self.streamed_args_for_tool = [] + self.tool_call_start_token = "" + self.tool_call_end_token = "" + self.tool_call_regex = re.compile(r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile(r"(.*?)", re.DOTALL) + + vocab = self.model_tokenizer.get_vocab() + start_id = vocab.get(self.tool_call_start_token) + end_id = vocab.get(self.tool_call_end_token) + + self.tool_call_start_token_ids = [start_id] if start_id is not None else [] + self.tool_call_end_token_ids = [end_id] if end_id is not None else [] + self.tool_call_start_token_array = [self.tool_call_start_token] + self.tool_call_end_token_array = [self.tool_call_end_token] + self.buffered_delta_text = "" diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 886cb1f836e..328d90d2a7f 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Optional from omegaconf import MISSING @@ -87,6 +87,8 @@ class AgentLoopConfig(BaseConfig): # Fully qualified class name for custom AgentLoopManager (e.g., "mypackage.module.MyManager"). # Security: This class will be dynamically imported via importlib. Only use trusted class paths. agent_loop_manager_class: Optional[str] = None + # nemo-gym config (nemo_gym_root, initial_global_config_dict, uses_reasoning_parser, etc.) + nemo_gym: Optional[Any] = None @dataclass diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index ff37780cff2..444bd9ddf69 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -397,6 +397,19 @@ async def run_server(self, args: argparse.Namespace): method="monkey_patch_model", kwargs={"vocab_size": len(self.model_config.tokenizer)} ) + # Load custom tool parser plugin before build_app so it's registered before any requests arrive + _tool_parser_plugin = getattr(args, "tool_parser_plugin", None) + if _tool_parser_plugin: + try: + try: + from vllm.tool_parsers import ToolParserManager + except ImportError: + from vllm.entrypoints.openai.tool_parsers import ToolParserManager + ToolParserManager.import_tool_parser(_tool_parser_plugin) + logger.info(f"Loaded tool parser plugin: {_tool_parser_plugin}") + except Exception as e: + logger.warning(f"Failed to load tool parser plugin {_tool_parser_plugin}: {e}") + build_app_sig = inspect.signature(build_app) supported_tasks: tuple[Any, ...] = () if "supported_tasks" in build_app_sig.parameters: From 896bd911401605cb3378d76046d8571dd0f28158 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:22:21 -0700 Subject: [PATCH 02/19] remove path stuff Signed-off-by: cmunley1 --- submit_math.sh | 1 - submit_workplace.sh | 1 - verl/experimental/nemo_gym_agent_loop.py | 10 ---------- 3 files changed, 12 deletions(-) diff --git a/submit_math.sh b/submit_math.sh index cff7f20b915..eb64d6beb0d 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -87,7 +87,6 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ RAY_ADDRESS="auto" \ VLLM_USE_V1=1 \ TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ - PYTHONPATH="${NEMO_GYM_ROOT}" \ RAY_grpc_keepalive_time_ms=60000 \ RAY_grpc_keepalive_timeout_ms=600000 \ RAY_grpc_client_keepalive_time_ms=60000 \ diff --git a/submit_workplace.sh b/submit_workplace.sh index ec85a6bab74..697e26df4e5 100755 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -88,7 +88,6 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ VLLM_USE_V1=1 \ TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ - PYTHONPATH="${NEMO_GYM_ROOT}" \ RAY_grpc_keepalive_time_ms=60000 \ RAY_grpc_keepalive_timeout_ms=600000 \ RAY_grpc_client_keepalive_time_ms=60000 \ diff --git a/verl/experimental/nemo_gym_agent_loop.py b/verl/experimental/nemo_gym_agent_loop.py index 40882951f69..d70a606c6d3 100644 --- a/verl/experimental/nemo_gym_agent_loop.py +++ b/verl/experimental/nemo_gym_agent_loop.py @@ -16,7 +16,6 @@ import asyncio import os import socket -import sys import threading from typing import Optional @@ -64,12 +63,7 @@ async def create( async def _init_nemo_gym(self) -> None: nemo_gym_cfg = self.rollout_config.agent.get("nemo_gym", {}) - # PYTHONPATH env var isn't reliable inside Ray actor processes (already running - # when env var is set); sys.path injection works instead. - # TODO: seems hacky. proper method maybe is pip install? but thats a bit broken nemo_gym_root = nemo_gym_cfg.get("nemo_gym_root", None) - if nemo_gym_root and str(nemo_gym_root) not in sys.path: - sys.path.insert(0, str(nemo_gym_root)) try: from nemo_gym.cli import GlobalConfigDictParserConfig, RunHelper @@ -124,10 +118,6 @@ async def _init_nemo_gym(self) -> None: initial_global_cfg[HEAD_SERVER_KEY_NAME] = {"host": "0.0.0.0", "port": head_port} self._head_server_config = BaseServerConfig(host=node_ip, port=head_port) - if nemo_gym_root: - existing = os.environ.get("PYTHONPATH", "") - os.environ["PYTHONPATH"] = f"{nemo_gym_root}:{existing}" if existing else str(nemo_gym_root) - # Auto-detect agent ref. maybe dangerous for multi-environment # TODO test multienv self._default_agent_ref = None From f265516a65c058c5998ca00ee7ee15636164a1b0 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:31:36 -0700 Subject: [PATCH 03/19] remove toolparser, make nemogym folder Signed-off-by: cmunley1 --- .gitignore | 4 +- submit_math.sh | 2 +- submit_workplace.sh | 2 +- verl/experimental/nemo_gym/__init__.py | 0 .../agent_loop.py} | 8 +-- .../experimental/nemo_gym/config.env.example | 0 .../dataset.py} | 0 .../hermes_tool_parser_patched.py | 57 ------------------- verl/workers/config/rollout.py | 11 +++- .../rollout/vllm_rollout/vllm_async_server.py | 13 ----- 10 files changed, 16 insertions(+), 81 deletions(-) create mode 100644 verl/experimental/nemo_gym/__init__.py rename verl/experimental/{nemo_gym_agent_loop.py => nemo_gym/agent_loop.py} (98%) rename config.env.example => verl/experimental/nemo_gym/config.env.example (100%) rename verl/experimental/{nemo_gym_dataset.py => nemo_gym/dataset.py} (100%) delete mode 100644 verl/experimental/tool_parsers/hermes_tool_parser_patched.py diff --git a/.gitignore b/.gitignore index 0aa648b904e..e6c0f5a08e3 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,4 @@ ENV/ logs log outputs -.history/ -logs/ -config.env +.history \ No newline at end of file diff --git a/submit_math.sh b/submit_math.sh index eb64d6beb0d..898d27410d2 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -161,7 +161,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ trainer.default_local_dir="${CKPTS_DIR}" \ trainer.resume_mode=auto \ trainer.log_val_generations=10 \ - +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym_agent_loop.NemoGymAgentLoopManager' \ + +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/math_with_judge/configs/math_with_judge.yaml]" \ +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ diff --git a/submit_workplace.sh b/submit_workplace.sh index 697e26df4e5..9cc24b10ddb 100755 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -164,7 +164,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ trainer.default_local_dir="${CKPTS_DIR}" \ trainer.resume_mode=auto \ trainer.log_val_generations=10 \ - +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym_agent_loop.NemoGymAgentLoopManager' \ + +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/workplace_assistant/configs/workplace_assistant.yaml]" \ +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ diff --git a/verl/experimental/nemo_gym/__init__.py b/verl/experimental/nemo_gym/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/verl/experimental/nemo_gym_agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py similarity index 98% rename from verl/experimental/nemo_gym_agent_loop.py rename to verl/experimental/nemo_gym/agent_loop.py index d70a606c6d3..746f7c8beb1 100644 --- a/verl/experimental/nemo_gym_agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -61,9 +61,9 @@ async def create( return instance async def _init_nemo_gym(self) -> None: - nemo_gym_cfg = self.rollout_config.agent.get("nemo_gym", {}) + nemo_gym_cfg = self.rollout_config.agent.nemo_gym - nemo_gym_root = nemo_gym_cfg.get("nemo_gym_root", None) + nemo_gym_root = nemo_gym_cfg.nemo_gym_root if nemo_gym_cfg else None try: from nemo_gym.cli import GlobalConfigDictParserConfig, RunHelper @@ -78,9 +78,9 @@ async def _init_nemo_gym(self) -> None: ) from e raise - initial_global_cfg = dict(nemo_gym_cfg.get("initial_global_config_dict", {})) + initial_global_cfg = dict(nemo_gym_cfg.initial_global_config_dict or {}) if nemo_gym_cfg else {} - uses_reasoning_parser = bool(nemo_gym_cfg.get("uses_reasoning_parser", False)) + uses_reasoning_parser = nemo_gym_cfg.uses_reasoning_parser if nemo_gym_cfg else False vllm_model_cfg = ( initial_global_cfg.setdefault("policy_model", {}) .setdefault("responses_api_models", {}) diff --git a/config.env.example b/verl/experimental/nemo_gym/config.env.example similarity index 100% rename from config.env.example rename to verl/experimental/nemo_gym/config.env.example diff --git a/verl/experimental/nemo_gym_dataset.py b/verl/experimental/nemo_gym/dataset.py similarity index 100% rename from verl/experimental/nemo_gym_dataset.py rename to verl/experimental/nemo_gym/dataset.py diff --git a/verl/experimental/tool_parsers/hermes_tool_parser_patched.py b/verl/experimental/tool_parsers/hermes_tool_parser_patched.py deleted file mode 100644 index f697863e154..00000000000 --- a/verl/experimental/tool_parsers/hermes_tool_parser_patched.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import re - -try: - from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager -except ImportError: - from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager - -try: - from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser -except ImportError: - from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser - - -@ToolParserManager.register_module("hermes_patched") -class HermesPatchedToolParser(Hermes2ProToolParser): - def __init__(self, tokenizer): - ToolParser.__init__(self, tokenizer) - - try: - from vllm.transformers_utils.tokenizer import MistralTokenizer - - if isinstance(self.model_tokenizer, MistralTokenizer): - self.model_tokenizer = self.model_tokenizer.tokenizer - except ImportError: - pass - - self.current_tool_name_sent = False - self.prev_tool_call_arr = [] - self.current_tool_id = -1 - self.streamed_args_for_tool = [] - self.tool_call_start_token = "" - self.tool_call_end_token = "" - self.tool_call_regex = re.compile(r"(.*?)|(.*)", re.DOTALL) - self.scratch_pad_regex = re.compile(r"(.*?)", re.DOTALL) - - vocab = self.model_tokenizer.get_vocab() - start_id = vocab.get(self.tool_call_start_token) - end_id = vocab.get(self.tool_call_end_token) - - self.tool_call_start_token_ids = [start_id] if start_id is not None else [] - self.tool_call_end_token_ids = [end_id] if end_id is not None else [] - self.tool_call_start_token_array = [self.tool_call_start_token] - self.tool_call_end_token_array = [self.tool_call_end_token] - self.buffered_delta_text = "" diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 328d90d2a7f..d6bcb29f415 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -78,6 +78,14 @@ class CustomAsyncServerConfig(BaseConfig): name: Optional[str] = None +@dataclass +class NemoGymConfig(BaseConfig): + nemo_gym_root: Optional[str] = None + uses_reasoning_parser: bool = False + # Passthrough to nemo-gym's GlobalConfigDictParserConfig — verl does not own this schema. + initial_global_config_dict: Optional[Any] = None + + @dataclass class AgentLoopConfig(BaseConfig): num_workers: int = 8 @@ -87,8 +95,7 @@ class AgentLoopConfig(BaseConfig): # Fully qualified class name for custom AgentLoopManager (e.g., "mypackage.module.MyManager"). # Security: This class will be dynamically imported via importlib. Only use trusted class paths. agent_loop_manager_class: Optional[str] = None - # nemo-gym config (nemo_gym_root, initial_global_config_dict, uses_reasoning_parser, etc.) - nemo_gym: Optional[Any] = None + nemo_gym: Optional[NemoGymConfig] = None @dataclass diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 444bd9ddf69..ff37780cff2 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -397,19 +397,6 @@ async def run_server(self, args: argparse.Namespace): method="monkey_patch_model", kwargs={"vocab_size": len(self.model_config.tokenizer)} ) - # Load custom tool parser plugin before build_app so it's registered before any requests arrive - _tool_parser_plugin = getattr(args, "tool_parser_plugin", None) - if _tool_parser_plugin: - try: - try: - from vllm.tool_parsers import ToolParserManager - except ImportError: - from vllm.entrypoints.openai.tool_parsers import ToolParserManager - ToolParserManager.import_tool_parser(_tool_parser_plugin) - logger.info(f"Loaded tool parser plugin: {_tool_parser_plugin}") - except Exception as e: - logger.warning(f"Failed to load tool parser plugin {_tool_parser_plugin}: {e}") - build_app_sig = inspect.signature(build_app) supported_tasks: tuple[Any, ...] = () if "supported_tasks" in build_app_sig.parameters: From 47fc7769c65231e003a8f8cec695f4f2cff56dad Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:34:33 -0700 Subject: [PATCH 04/19] dont except omegaconf Signed-off-by: cmunley1 --- verl/experimental/nemo_gym/agent_loop.py | 12 +++++------- verl/workers/config/rollout.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 746f7c8beb1..86c31c0b3ee 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -65,18 +65,16 @@ async def _init_nemo_gym(self) -> None: nemo_gym_root = nemo_gym_cfg.nemo_gym_root if nemo_gym_cfg else None + from omegaconf import DictConfig + try: from nemo_gym.cli import GlobalConfigDictParserConfig, RunHelper from nemo_gym.rollout_collection import RolloutCollectionHelper from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, BaseServerConfig - from omegaconf import DictConfig except ModuleNotFoundError as e: - if "nemo_gym" in str(e): - raise ImportError( - "nemo-gym not found. Set nemo_gym_root in config: " - "+actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root=/path/to/gym-ref" - ) from e - raise + raise ImportError( + "nemo-gym not found. Install it with: pip install -e /path/to/gym-ref" + ) from e initial_global_cfg = dict(nemo_gym_cfg.initial_global_config_dict or {}) if nemo_gym_cfg else {} diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index d6bcb29f415..e952d0b9d1f 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -82,7 +82,7 @@ class CustomAsyncServerConfig(BaseConfig): class NemoGymConfig(BaseConfig): nemo_gym_root: Optional[str] = None uses_reasoning_parser: bool = False - # Passthrough to nemo-gym's GlobalConfigDictParserConfig — verl does not own this schema. + # uses NeMo Gym's GlobalConfigDictParserConfig initial_global_config_dict: Optional[Any] = None From 0f85db85f02b8c36f4042165874bbd4c7e144c1d Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:46:04 -0700 Subject: [PATCH 05/19] updates Signed-off-by: cmunley1 --- docs/examples/nemo_gym.rst | 75 ++++++++++++++++++++++++ docs/index.rst | 1 + submit_math.sh | 2 + submit_workplace.sh | 2 + verl/experimental/nemo_gym/agent_loop.py | 2 +- verl/experimental/nemo_gym/dataset.py | 2 +- 6 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 docs/examples/nemo_gym.rst diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst new file mode 100644 index 00000000000..452386c38be --- /dev/null +++ b/docs/examples/nemo_gym.rst @@ -0,0 +1,75 @@ +Multi-Turn Tool Use with NeMo Gym +================================== + +`NeMo Gym `_ is an environment framework for +multi-turn RL rollouts with tool-calling agents. The verl integration lets you +replace verl's standard single-turn rollout with NeMo Gym's ``simple_agent``, +which runs full multi-turn conversations through an OpenAI-compatible HTTP +endpoint and returns token IDs and log-probs for training. + +Overview +-------- + +The integration adds two components to ``verl/experimental/nemo_gym/``: + +- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: drives multi-turn rollouts + via NeMo Gym, handles token ID reconciliation across turns, and returns a + ``DataProto`` compatible with verl's Megatron actor update. +- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym JSONL files + (including tool definitions, agent refs, and ground-truth answers) into + verl's data pipeline. + +Requirements +------------ + +- A NeMo Gym checkout (``gym-ref``) with the environment you want to train on. +- ``pip install -e /path/to/gym-ref`` installed into the container at job start. + +Quick Start +----------- + +1. **Install NeMo Gym** in your container startup script:: + + pip install -e /path/to/gym-ref + +2. **Prepare your dataset** in NeMo Gym JSONL format. Each line should be a + JSON object with a ``responses_create_params`` field containing the initial + messages and any tools, plus an ``agent_ref`` pointing at your environment's + agent server. + +3. **Add these overrides** to your verl training command:: + + +data.custom_cls.path=verl/experimental/nemo_gym/dataset.py + +data.custom_cls.name=NemoGymJSONLDataset + +actor_rollout_ref.rollout.agent.agent_loop_manager_class=verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager + "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[/path/to/env.yaml]" + +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root=/path/to/gym-ref + +See ``submit_workplace.sh`` and ``submit_math.sh`` for complete working examples. + +Configuration +------------- + +The ``nemo_gym`` block in ``AgentLoopConfig`` accepts: + +.. code-block:: yaml + + actor_rollout_ref: + rollout: + agent: + nemo_gym: + nemo_gym_root: /path/to/gym-ref + uses_reasoning_parser: false + initial_global_config_dict: + config_paths: + - /path/to/env.yaml + +Tool Calling +------------ + +For environments that use tool calling (e.g. workplace assistant), pass the +vLLM engine kwargs to enable the hermes tool parser:: + + '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' + '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes' + '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' diff --git a/docs/index.rst b/docs/index.rst index 381d3a6bad9..2bac7398448 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,6 +63,7 @@ verl is fast with: examples/gsm8k_example examples/multi_modal_example examples/skypilot_examples + examples/nemo_gym .. toctree:: :maxdepth: 1 diff --git a/submit_math.sh b/submit_math.sh index 898d27410d2..de365c44706 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -2,6 +2,8 @@ #SBATCH --job-name=verl-nemogym-dapo-7b-math #SBATCH --nodes=4 #SBATCH --ntasks-per-node=1 +#SBATCH --partition=your_partition +#SBATCH --account=your_account #SBATCH --time=4:00:00 #SBATCH --gres=gpu:8 #SBATCH --exclusive diff --git a/submit_workplace.sh b/submit_workplace.sh index 9cc24b10ddb..c8e985bb934 100755 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -2,6 +2,8 @@ #SBATCH --job-name=verl-nemogym-dapo-4b-workplace #SBATCH --nodes=4 #SBATCH --ntasks-per-node=1 +#SBATCH --partition=your_partition +#SBATCH --account=your_account #SBATCH --time=4:00:00 #SBATCH --gres=gpu:8 #SBATCH --exclusive diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 86c31c0b3ee..63389f23e32 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/verl/experimental/nemo_gym/dataset.py b/verl/experimental/nemo_gym/dataset.py index f3bdc4ed228..ab29089f170 100644 --- a/verl/experimental/nemo_gym/dataset.py +++ b/verl/experimental/nemo_gym/dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 7b2a147d2cc73d1d4ca3c17f0ef31a0b234bf6bb Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:46:36 -0700 Subject: [PATCH 06/19] docs Signed-off-by: cmunley1 --- docs/examples/nemo_gym.rst | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst index 452386c38be..30d0b7fcc9d 100644 --- a/docs/examples/nemo_gym.rst +++ b/docs/examples/nemo_gym.rst @@ -1,28 +1,25 @@ Multi-Turn Tool Use with NeMo Gym ================================== -`NeMo Gym `_ is an environment framework for -multi-turn RL rollouts with tool-calling agents. The verl integration lets you -replace verl's standard single-turn rollout with NeMo Gym's ``simple_agent``, -which runs full multi-turn conversations through an OpenAI-compatible HTTP -endpoint and returns token IDs and log-probs for training. +`NVIDIA NeMo Gym `_ is an RL environment framework for +scalable, multi-environment, agentic RL. This integration enables running NeMo Gym environments with +verl using a custom agent loop manager. Overview -------- The integration adds two components to ``verl/experimental/nemo_gym/``: -- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: drives multi-turn rollouts - via NeMo Gym, handles token ID reconciliation across turns, and returns a - ``DataProto`` compatible with verl's Megatron actor update. -- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym JSONL files - (including tool definitions, agent refs, and ground-truth answers) into - verl's data pipeline. +- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: offloads multi-turn rollouts + to NeMo Gym, handles retokenization correction across turns, and formats output. + The retokenization logic may change shortly to follow verl approach. +- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym datasets + including messages, tools, agent refs, and metadata into verl format. Requirements ------------ -- A NeMo Gym checkout (``gym-ref``) with the environment you want to train on. +- A NeMo Gym local clone (``gym-ref``) with the environment you want to train on. TODO finalize submodule decision - ``pip install -e /path/to/gym-ref`` installed into the container at job start. Quick Start @@ -32,7 +29,7 @@ Quick Start pip install -e /path/to/gym-ref -2. **Prepare your dataset** in NeMo Gym JSONL format. Each line should be a +2. **Prepare training datasets** in NeMo Gym JSONL format. Each line should be a JSON object with a ``responses_create_params`` field containing the initial messages and any tools, plus an ``agent_ref`` pointing at your environment's agent server. @@ -45,7 +42,7 @@ Quick Start "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[/path/to/env.yaml]" +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root=/path/to/gym-ref -See ``submit_workplace.sh`` and ``submit_math.sh`` for complete working examples. +See ``submit_workplace.sh`` and ``submit_math.sh`` for working examples. Configuration ------------- @@ -67,8 +64,7 @@ The ``nemo_gym`` block in ``AgentLoopConfig`` accepts: Tool Calling ------------ -For environments that use tool calling (e.g. workplace assistant), pass the -vLLM engine kwargs to enable the hermes tool parser:: +For environments that use tool calling (e.g. workplace assistant), use a tool parser, for example:: '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes' From af9213648ed305f2ad5363e368de56d7ea564d3b Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:47:52 -0700 Subject: [PATCH 07/19] docs Signed-off-by: cmunley1 --- docs/examples/nemo_gym.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst index 30d0b7fcc9d..dce990d7a81 100644 --- a/docs/examples/nemo_gym.rst +++ b/docs/examples/nemo_gym.rst @@ -1,9 +1,9 @@ Multi-Turn Tool Use with NeMo Gym ================================== -`NVIDIA NeMo Gym `_ is an RL environment framework for -scalable, multi-environment, agentic RL. This integration enables running NeMo Gym environments with -verl using a custom agent loop manager. +`NVIDIA NeMo Gym `_ (`docs `_) +is an RL environment framework for scalable, multi-environment, agentic RL. This integration enables +running NeMo Gym environments with verl using a custom agent loop manager. Overview -------- From 53ab67016035aed0c5f8969750a1049f95943364 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:48:34 -0700 Subject: [PATCH 08/19] docs Signed-off-by: cmunley1 --- docs/examples/nemo_gym.rst | 2 +- submit_math.sh | 2 +- submit_workplace.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst index dce990d7a81..93ffec537ac 100644 --- a/docs/examples/nemo_gym.rst +++ b/docs/examples/nemo_gym.rst @@ -1,4 +1,4 @@ -Multi-Turn Tool Use with NeMo Gym +NVIDIA NeMo Gym Integration ================================== `NVIDIA NeMo Gym `_ (`docs `_) diff --git a/submit_math.sh b/submit_math.sh index de365c44706..01ec25b93f3 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -14,7 +14,7 @@ set -euo pipefail GPUS_PER_NODE=8 -source "$(dirname "$0")/config.env" +source "${SLURM_SUBMIT_DIR}/config.env" MODEL_PATH="${DATA_ROOT}/models/Qwen2.5-Math-7B" TRAIN_FILE="${NEMO_GYM_ROOT}/resources_servers/math_with_judge/data/dapo17k_bytedtsinghua_train_nrl.jsonl" diff --git a/submit_workplace.sh b/submit_workplace.sh index c8e985bb934..f6d6cad1058 100755 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -14,7 +14,7 @@ set -euo pipefail GPUS_PER_NODE=8 -source "$(dirname "$0")/config.env" +source "${SLURM_SUBMIT_DIR}/config.env" MODEL_PATH="${HF_HOME}/hub/models--Qwen--Qwen3-4B-Instruct-2507/snapshots/cdbee75f17c01a7cc42f958dc650907174af0554" TRAIN_FILE="${DATA_ROOT}/workplace_assistant/train.jsonl" From 591d91248fea024c587581d888a05ac8a45a95f7 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:54:36 -0700 Subject: [PATCH 09/19] updates Signed-off-by: cmunley1 --- docs/examples/nemo_gym.rst | 2 +- submit_math.sh | 2 +- submit_workplace.sh | 2 +- verl/experimental/nemo_gym/agent_loop.py | 3 ++- verl/workers/config/rollout.py | 3 +-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst index 93ffec537ac..bb72da41f00 100644 --- a/docs/examples/nemo_gym.rst +++ b/docs/examples/nemo_gym.rst @@ -39,7 +39,7 @@ Quick Start +data.custom_cls.path=verl/experimental/nemo_gym/dataset.py +data.custom_cls.name=NemoGymJSONLDataset +actor_rollout_ref.rollout.agent.agent_loop_manager_class=verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager - "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[/path/to/env.yaml]" + "+actor_rollout_ref.rollout.agent.nemo_gym.config_paths=[/path/to/env.yaml]" +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root=/path/to/gym-ref See ``submit_workplace.sh`` and ``submit_math.sh`` for working examples. diff --git a/submit_math.sh b/submit_math.sh index 01ec25b93f3..ce54b8a1c16 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -164,7 +164,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ trainer.resume_mode=auto \ trainer.log_val_generations=10 \ +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ - "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/math_with_judge/configs/math_with_judge.yaml]" \ + "+actor_rollout_ref.rollout.agent.nemo_gym.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/math_with_judge/configs/math_with_judge.yaml]" \ +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ 2>&1 diff --git a/submit_workplace.sh b/submit_workplace.sh index f6d6cad1058..1652b4fdb25 100755 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -167,7 +167,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ trainer.resume_mode=auto \ trainer.log_val_generations=10 \ +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ - "+actor_rollout_ref.rollout.agent.nemo_gym.initial_global_config_dict.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/workplace_assistant/configs/workplace_assistant.yaml]" \ + "+actor_rollout_ref.rollout.agent.nemo_gym.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/workplace_assistant/configs/workplace_assistant.yaml]" \ +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ 2>&1 diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 63389f23e32..b75256740ec 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -76,7 +76,8 @@ async def _init_nemo_gym(self) -> None: "nemo-gym not found. Install it with: pip install -e /path/to/gym-ref" ) from e - initial_global_cfg = dict(nemo_gym_cfg.initial_global_config_dict or {}) if nemo_gym_cfg else {} + config_paths = list(nemo_gym_cfg.config_paths or []) if nemo_gym_cfg else [] + initial_global_cfg = {"config_paths": config_paths} if config_paths else {} uses_reasoning_parser = nemo_gym_cfg.uses_reasoning_parser if nemo_gym_cfg else False vllm_model_cfg = ( diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index e952d0b9d1f..ee0b8111b61 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -82,8 +82,7 @@ class CustomAsyncServerConfig(BaseConfig): class NemoGymConfig(BaseConfig): nemo_gym_root: Optional[str] = None uses_reasoning_parser: bool = False - # uses NeMo Gym's GlobalConfigDictParserConfig - initial_global_config_dict: Optional[Any] = None + config_paths: Optional[list] = None @dataclass From 15380342c07180b3f6814d953e71ea03b2babd04 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 20:55:38 -0700 Subject: [PATCH 10/19] updates Signed-off-by: cmunley1 --- submit_math.sh | 2 +- submit_workplace.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/submit_math.sh b/submit_math.sh index ce54b8a1c16..272927b10a2 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -98,7 +98,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ --config-name=dapo_megatron_trainer.yaml \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ - +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym_dataset.py" \ + +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym/dataset.py" \ +data.custom_cls.name=NemoGymJSONLDataset \ data.truncation=left \ data.max_prompt_length=2048 \ diff --git a/submit_workplace.sh b/submit_workplace.sh index 1652b4fdb25..8639a41fde5 100755 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -99,7 +99,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ --config-name=dapo_megatron_trainer.yaml \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ - +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym_dataset.py" \ + +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym/dataset.py" \ +data.custom_cls.name=NemoGymJSONLDataset \ data.truncation=left \ data.train_batch_size=512 \ From b7a0a5963f33b5b59f6a568e73373a6a5451cbf4 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 27 Mar 2026 21:44:42 -0700 Subject: [PATCH 11/19] data path Signed-off-by: cmunley1 --- submit_math.sh | 5 +++-- submit_workplace.sh | 1 + verl/experimental/nemo_gym/agent_loop.py | 6 ++++++ verl/experimental/nemo_gym/config.env.example | 2 -- 4 files changed, 10 insertions(+), 4 deletions(-) mode change 100755 => 100644 submit_workplace.sh diff --git a/submit_math.sh b/submit_math.sh index 272927b10a2..874f3ef33ab 100755 --- a/submit_math.sh +++ b/submit_math.sh @@ -17,8 +17,8 @@ GPUS_PER_NODE=8 source "${SLURM_SUBMIT_DIR}/config.env" MODEL_PATH="${DATA_ROOT}/models/Qwen2.5-Math-7B" -TRAIN_FILE="${NEMO_GYM_ROOT}/resources_servers/math_with_judge/data/dapo17k_bytedtsinghua_train_nrl.jsonl" -TEST_FILE="${NEMO_GYM_ROOT}/resources_servers/math_with_judge/data/aime24_bytedtsinghua_validation_nrl.jsonl" +TRAIN_FILE="${DATA_ROOT}/math_with_judge/dapo17k_bytedtsinghua_train_nrl.jsonl" +TEST_FILE="${DATA_ROOT}/math_with_judge/aime24_bytedtsinghua_validation_nrl.jsonl" CKPTS_DIR="${RESULTS_ROOT}/DAPO-Qwen2.5-7b-MATH-megatron" CONTAINER="verlai/verl:vllm017.latest" @@ -89,6 +89,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ RAY_ADDRESS="auto" \ VLLM_USE_V1=1 \ TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + PYTHONPATH="${NEMO_GYM_ROOT}" \ RAY_grpc_keepalive_time_ms=60000 \ RAY_grpc_keepalive_timeout_ms=600000 \ RAY_grpc_client_keepalive_time_ms=60000 \ diff --git a/submit_workplace.sh b/submit_workplace.sh old mode 100755 new mode 100644 index 8639a41fde5..bcc3ca05df8 --- a/submit_workplace.sh +++ b/submit_workplace.sh @@ -89,6 +89,7 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ RAY_ADDRESS="auto" \ VLLM_USE_V1=1 \ TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + PYTHONPATH="${NEMO_GYM_ROOT}" \ VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ RAY_grpc_keepalive_time_ms=60000 \ RAY_grpc_keepalive_timeout_ms=600000 \ diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index b75256740ec..5f948568be0 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -16,6 +16,7 @@ import asyncio import os import socket +import sys import threading from typing import Optional @@ -65,6 +66,11 @@ async def _init_nemo_gym(self) -> None: nemo_gym_root = nemo_gym_cfg.nemo_gym_root if nemo_gym_cfg else None + # Insert nemo_gym_root into sys.path so Ray remote actors (which don't + # inherit PYTHONPATH from the driver) can import nemo-gym from the shared fs. + if nemo_gym_root and str(nemo_gym_root) not in sys.path: + sys.path.insert(0, str(nemo_gym_root)) + from omegaconf import DictConfig try: diff --git a/verl/experimental/nemo_gym/config.env.example b/verl/experimental/nemo_gym/config.env.example index 0efe623aefc..6798fd72409 100644 --- a/verl/experimental/nemo_gym/config.env.example +++ b/verl/experimental/nemo_gym/config.env.example @@ -1,5 +1,3 @@ -# Copy this to config.env and fill in your values. config.env is gitignored. - VERL_ROOT=/path/to/verl NEMO_GYM_ROOT=/path/to/gym-ref HF_HOME=/path/to/hf_home From 902e83114a85c00bef30601e2ac5f27b9bad5863 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 28 Mar 2026 18:38:04 -0700 Subject: [PATCH 12/19] patch vllm server for nemogym retokenization fix Signed-off-by: cmunley1 --- verl/experimental/nemo_gym/agent_loop.py | 5 + verl/experimental/nemo_gym/server_patch.py | 99 +++++++++++++++++++ .../rollout/vllm_rollout/vllm_async_server.py | 5 + 3 files changed, 109 insertions(+) create mode 100644 verl/experimental/nemo_gym/server_patch.py diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 5f948568be0..1111a602e7b 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -58,9 +58,14 @@ async def create( ) await instance._initialize_llm_servers() await instance._init_global_load_balancer() + await instance._apply_server_patch() await instance._init_nemo_gym() return instance + async def _apply_server_patch(self) -> None: + futs = [handle.apply_nemo_gym_server_patch.remote() for handle in self.server_handles] + await asyncio.get_event_loop().run_in_executor(None, ray.get, futs) + async def _init_nemo_gym(self) -> None: nemo_gym_cfg = self.rollout_config.agent.nemo_gym diff --git a/verl/experimental/nemo_gym/server_patch.py b/verl/experimental/nemo_gym/server_patch.py new file mode 100644 index 00000000000..ea86d5cfd75 --- /dev/null +++ b/verl/experimental/nemo_gym/server_patch.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +logger = logging.getLogger(__name__) + +def _replace_prefix_tokens(model_prefix, template_prefix, template_ids, tok): + eos = tok.eos_token_id + if eos is None or not model_prefix: + return template_ids + eos_set = set(eos) if isinstance(eos, list) else {eos} + cut_model = len(model_prefix) + if model_prefix[-1] in eos_set: + cut_model -= 1 + if len(template_ids) <= len(template_prefix): + return template_ids + cut = -1 + for pos in reversed(range(len(template_prefix))): + if template_ids[pos] in eos_set: + cut = pos + break + if cut < 0: + return template_ids + return model_prefix[:cut_model] + template_ids[cut:] + + +def patch_serving_chat_for_nemo_gym() -> None: + _serving_chat_cls = None + for _mod in ( + "vllm.entrypoints.openai.chat_completion.serving", + "vllm.entrypoints.openai.chat_completion", + "vllm.entrypoints.openai.api_server", + "vllm.entrypoints.openai.serving_chat", + ): + try: + import importlib + m = importlib.import_module(_mod) + if hasattr(m, "OpenAIServingChat"): + _serving_chat_cls = m.OpenAIServingChat + break + except ImportError: + continue + + if _serving_chat_cls is None: + logger.warning("[nemo-gym] could not find OpenAIServingChat; skipping retokenization patch.") + return + + OpenAIServingChat = _serving_chat_cls + _original_preprocess_chat = OpenAIServingChat._preprocess_chat + + async def _patched_preprocess_chat( + self, request, messages, + default_template, default_template_content_format, default_template_kwargs, + tool_dicts=None, tool_parser=None, + ): + required_prefix = getattr(request, "required_prefix_token_ids", None) + if required_prefix is None: + for msg in reversed(messages): + if isinstance(msg, dict) and "prompt_token_ids" in msg: + required_prefix = list(msg["prompt_token_ids"]) + list(msg["generation_token_ids"]) + break + elif not isinstance(msg, dict) and getattr(msg, "prompt_token_ids", None): + required_prefix = list(msg.prompt_token_ids) + list(msg.generation_token_ids) + break + + res = await _original_preprocess_chat( + self, request, messages, + default_template, default_template_content_format, default_template_kwargs, + tool_dicts=tool_dicts, tool_parser=tool_parser, + ) + + if required_prefix is None: + return res + + try: + tok = self.renderer.get_tokenizer() # avoid concurrent tokenizer access - else already borrowed error w/ hermes tool parser + engine_prompt = res[1][0] + engine_prompt["prompt_token_ids"] = _replace_prefix_tokens( + required_prefix, required_prefix, + engine_prompt["prompt_token_ids"], tok, + ) + except Exception as e: + logger.warning(f"[nemo-gym] retokenization patch failed, skipping: {e}") + return res + + OpenAIServingChat._preprocess_chat = _patched_preprocess_chat + logger.info("[nemo-gym] applied retokenization patch to OpenAIServingChat.") diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index ff37780cff2..6e071a80bae 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -170,6 +170,11 @@ def get_server_address(self): assert self._server_port is not None, "http server is not launched, port is None" return self._server_address, self._server_port + def apply_nemo_gym_server_patch(self): + # called by NemoGymAgentLoopManager to apply retokenization fix only for nemo-gym runs + from verl.experimental.nemo_gym.server_patch import patch_serving_chat_for_nemo_gym + patch_serving_chat_for_nemo_gym() + @property def lora_as_adapter(self) -> bool: return ( From 1bb9889d200df3fba2ab3f4c121b150c251a3c42 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sun, 29 Mar 2026 18:44:26 -0700 Subject: [PATCH 13/19] serverside patch and stuff Signed-off-by: cmunley1 --- verl/experimental/nemo_gym/agent_loop.py | 167 +++++---------------- verl/experimental/nemo_gym/server_patch.py | 112 +++++++++----- 2 files changed, 114 insertions(+), 165 deletions(-) diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 1111a602e7b..8055f7aa10a 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -14,10 +14,10 @@ from __future__ import annotations import asyncio -import os import socket import sys import threading +from collections import defaultdict from typing import Optional import ray @@ -71,8 +71,6 @@ async def _init_nemo_gym(self) -> None: nemo_gym_root = nemo_gym_cfg.nemo_gym_root if nemo_gym_cfg else None - # Insert nemo_gym_root into sys.path so Ray remote actors (which don't - # inherit PYTHONPATH from the driver) can import nemo-gym from the shared fs. if nemo_gym_root and str(nemo_gym_root) not in sys.path: sys.path.insert(0, str(nemo_gym_root)) @@ -98,7 +96,6 @@ async def _init_nemo_gym(self) -> None: ) vllm_model_cfg["uses_reasoning_parser"] = uses_reasoning_parser - # Disable thinking if no reasoning parser if not uses_reasoning_parser: vllm_model_cfg.setdefault("extra_body", {}).setdefault("chat_template_kwargs", {})["enable_thinking"] = ( False @@ -128,27 +125,6 @@ async def _init_nemo_gym(self) -> None: initial_global_cfg[HEAD_SERVER_KEY_NAME] = {"host": "0.0.0.0", "port": head_port} self._head_server_config = BaseServerConfig(host=node_ip, port=head_port) - # Auto-detect agent ref. maybe dangerous for multi-environment - # TODO test multienv - self._default_agent_ref = None - config_paths = initial_global_cfg.get("config_paths", []) - for config_path in config_paths if isinstance(config_paths, list) else []: - try: - from omegaconf import OmegaConf - - yaml_cfg = OmegaConf.load(config_path) - for key in yaml_cfg: - entry = yaml_cfg[key] - if isinstance(entry, dict) and "responses_api_agents" in entry: - self._default_agent_ref = {"type": "responses_api_agents", "name": key} - print(f"[NemoGymAgentLoopManager] Detected agent: {key}") - break - except Exception: - pass - if self._default_agent_ref: - break - - # start nemo gym servers self._rh = RunHelper() self._rh.start( global_config_dict_parser_config=GlobalConfigDictParserConfig( @@ -160,13 +136,10 @@ async def _init_nemo_gym(self) -> None: self._rch = RolloutCollectionHelper() - # AgentLoopManager stores model_config as raw DictConfig; convert here - # since AgentLoopWorker which normally does this isn't used. from verl.utils.config import omega_conf_to_dataclass self._tokenizer = omega_conf_to_dataclass(self.model_config).tokenizer - # asyncio.run() was recreating loop on each step and erroring.. so make a single persistent loop self._rollout_loop = asyncio.new_event_loop() self._rollout_thread = threading.Thread( target=self._rollout_loop.run_forever, @@ -189,25 +162,21 @@ async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: prompts, self.rollout_config, validate=validate, - default_agent_ref=self._default_agent_ref, ) - # run rollout collection nemo_gym_result_iterator = self._rch.run_examples( examples=nemo_gym_examples, head_server_config=self._head_server_config, ) - # collect results rowidxs, raw_results = [], [] for task in nemo_gym_result_iterator: nemo_gym_row, nemo_gym_result = await task try: result = _postprocess_nemo_gym_result(nemo_gym_result, self._tokenizer) except ValueError: - # Context length exceeded. return a dummy result (1 token, 0 reward, 0 logprob) - # TODO: should we fail here instead or what? i dont like dummy result. what does nemo rl do? result = _empty_result(nemo_gym_row, self._tokenizer) + result["env"] = nemo_gym_row["agent_ref"]["name"] rowidxs.append(nemo_gym_row["_rowidx"]) raw_results.append(result) @@ -215,14 +184,21 @@ async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: for rowidx, result in zip(rowidxs, raw_results, strict=False): results[rowidx] = result - # pad to batch max instead of setting data.max_prompt_length - # TODO: review what we are padding here, is this dangerous or wasteful? prompt_lens = [sum(len(m["token_ids"]) for m in r["input_message_log"]) for r in results] response_lens = [ sum(len(m["token_ids"]) for m in r["message_log"][len(r["input_message_log"]) :]) for r in results ] prompt_length = max(prompt_lens) if prompt_lens else self.rollout_config.prompt_length + _vllm_kwargs = (getattr(self.rollout_config, "engine_kwargs", {}) or {}).get("vllm", {}) or {} + max_model_len = ( + getattr(self.rollout_config, "max_model_len", None) + or _vllm_kwargs.get("max-model-len") + or _vllm_kwargs.get("max_model_len") + ) + response_budget = (int(max_model_len) - prompt_length) if max_model_len else None response_length = max(response_lens) if response_lens else self.rollout_config.response_length + if response_budget: + response_length = min(response_length, response_budget) internal_outputs = [ _nemo_gym_result_to_verl( @@ -241,11 +217,19 @@ async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: validate=validate, ) output.meta_info["global_steps"] = global_steps - # ray_trainer.py merges this into its own timing_raw at line ~1362; empty is fine + rollout_metrics = _compute_rollout_metrics(results, getattr(self.rollout_config, "max_model_len", None)) output.meta_info["timing"] = {} - output.meta_info["rollout_metrics"] = _compute_rollout_metrics( - results, getattr(self.rollout_config, "max_model_len", None) - ) + output.meta_info["rollout_metrics"] = rollout_metrics + + env_metrics = {k: v for k, v in rollout_metrics.items() if k.startswith("env/")} + if env_metrics: + try: + import wandb + if wandb.run is not None: + wandb.log({**env_metrics, "train/global_step": global_steps}, step=global_steps) + except Exception: + pass + return output @@ -253,7 +237,6 @@ def _build_nemo_gym_examples( prompts: DataProto, rollout_config, validate: bool = False, - default_agent_ref=None, ) -> list[dict]: cfg = rollout_config temperature = cfg.val_kwargs.temperature if validate else cfg.temperature @@ -264,14 +247,10 @@ def _build_nemo_gym_examples( for i in range(len(prompts)): messages = list(non_tensor["raw_prompt"][i]) - if "agent_ref" in non_tensor: - agent_ref = non_tensor["agent_ref"][i] - elif default_agent_ref is not None: - agent_ref = default_agent_ref - else: - agent_ref = {"type": "responses_api_agents", "name": "math_with_judge_simple_agent"} + if "agent_ref" not in non_tensor: + raise ValueError(f"dataset row {i} is missing agent_ref") + agent_ref = non_tensor["agent_ref"][i] - # Build responses_create_params rcp = {} if "extra_env_info" in non_tensor and "_rcp_extra" in (non_tensor["extra_env_info"][i] or {}): rcp.update(non_tensor["extra_env_info"][i]["_rcp_extra"]) @@ -294,61 +273,6 @@ def _build_nemo_gym_examples( return examples -def _replace_prefix_tokens( - tokenizer, - model_prefix_token_ids: list[int], - template_prefix_token_ids: list[int], - template_token_ids: list[int], -) -> list[int]: - """Fix chat-template re-tokenization differences across multi-turn calls. - - Find the first eos token that appears at or after the first - divergence point between model_prefix and template. That eos marks the - end of the re-tokenized assistant turn. Everything from that eos onward - in template_token_ids (eos + tool-result + gen-prompt) becomes the new - observation suffix, and the model's original prefix tokens are kept intact. - - TODO: review why original was not sufficient (re-tokenized turn can be shorter or longer) - TODO: verl avoids this entirely by appending token ids each turn without re-tokenizing - """ - if not model_prefix_token_ids: - return template_token_ids - raw_eos = tokenizer.eos_token_id - if raw_eos is None: - return template_token_ids - # handle case where eos_token_id is list (ideally there is just 1 lol) - eos_ids: set[int] = set(raw_eos) if isinstance(raw_eos, list) else {raw_eos} - - model_cut = len(model_prefix_token_ids) - if model_prefix_token_ids[-1] in eos_ids: - model_cut -= 1 - - # Find first position where the two sequences differ - first_diff = next( - (i for i, (a, b) in enumerate(zip(model_prefix_token_ids, template_token_ids, strict=False)) if a != b), - min(len(model_prefix_token_ids), len(template_token_ids)), - ) - - # look ahead from first_diff: find the eos that ends the re-tokenized assistant turn. - # works even when the re-tokenized turn is shorter than the original. - cut = -1 - for pos in range(first_diff, len(template_token_ids)): - if template_token_ids[pos] in eos_ids: - cut = pos - break - - # look back up to first_diff (for uncommon longer retokenized case) - if cut < 0: - for pos in reversed(range(first_diff)): - if template_token_ids[pos] in eos_ids: - cut = pos - break - - if cut < 0: - return template_token_ids - return model_prefix_token_ids[:model_cut] + template_token_ids[cut:] - - def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: message_log = [] seen_token_ids: list[int] = [] @@ -359,28 +283,9 @@ def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: prompt_ids = item["prompt_token_ids"] - # If token IDs are non-contiguous apply _replace_prefix_tokens fix to restore monotonic token IDs. - # TODO use verl way - if seen_token_ids and seen_token_ids != prompt_ids[: len(seen_token_ids)]: - prompt_ids = _replace_prefix_tokens( - tokenizer, - model_prefix_token_ids=seen_token_ids, - template_prefix_token_ids=seen_token_ids, - template_token_ids=prompt_ids, - ) - - if seen_token_ids != prompt_ids[: len(seen_token_ids)]: - diverge_at = next( - (i for i, (a, b) in enumerate(zip(seen_token_ids, prompt_ids, strict=False)) if a != b), - len(seen_token_ids), - ) - raise AssertionError( - f"Non-contiguous token IDs after replace_prefix fix. " - f"seen_len={len(seen_token_ids)} prompt_len={len(prompt_ids)} " - f"first_diverge={diverge_at} " - f"seen[{diverge_at - 2}:{diverge_at + 3}]={seen_token_ids[max(0, diverge_at - 2) : diverge_at + 3]} " - f"prompt[{diverge_at - 2}:{diverge_at + 3}]={prompt_ids[max(0, diverge_at - 2) : diverge_at + 3]}" - ) + assert ( + seen_token_ids == prompt_ids[: len(seen_token_ids)] + ), f"Non-contiguous token IDs (server_patch active?). seen={len(seen_token_ids)} prompt={len(prompt_ids)}" message_log.append( { @@ -409,7 +314,7 @@ def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: if not message_log: raise ValueError( "nemo-gym returned a result with no generation data. " - "The prompt (ie full context in multi step/turn) probably exceeds vLLM's max_model_len." + "The prompt may exceed vLLM's max_model_len." ) return { @@ -561,7 +466,7 @@ def _mean(vals): } ) - return { + metrics = { "turns_per_sample/mean": _mean([s["turns"] for s in stats]), "total_tokens_per_sample/mean": _mean([s["total"] for s in stats]), "gen_tokens_per_sample/mean": _mean([s["asst"] for s in stats]), @@ -570,3 +475,13 @@ def _mean(vals): "natural_termination_rate": sum(not s["hit_max"] for s in stats) / batch_size, "truncation_rate": sum(s["hit_max"] for s in stats) / batch_size, } + + # per environment metrics + env_rewards: dict = defaultdict(list) + for r, s in zip(results, stats, strict=True): + env_rewards[r.get("env", "unknown")].append(s["reward"]) + if len(env_rewards) > 1: + for env, rewards in env_rewards.items(): + metrics[f"env/{env}/reward_mean"] = _mean(rewards) + + return metrics diff --git a/verl/experimental/nemo_gym/server_patch.py b/verl/experimental/nemo_gym/server_patch.py index ea86d5cfd75..cbc0123ce3d 100644 --- a/verl/experimental/nemo_gym/server_patch.py +++ b/verl/experimental/nemo_gym/server_patch.py @@ -16,51 +16,31 @@ logger = logging.getLogger(__name__) + def _replace_prefix_tokens(model_prefix, template_prefix, template_ids, tok): - eos = tok.eos_token_id - if eos is None or not model_prefix: + # matches nemo-rl's implementation + if not model_prefix: return template_ids - eos_set = set(eos) if isinstance(eos, list) else {eos} + eos = tok.eos_token_id + assert eos is not None, "tokenizer must have eos_token_id" cut_model = len(model_prefix) - if model_prefix[-1] in eos_set: + if model_prefix[-1] == eos: cut_model -= 1 - if len(template_ids) <= len(template_prefix): - return template_ids + assert len(template_ids) > len(template_prefix), ( + f"non-monotonically increasing trajectory: " + f"template_ids={len(template_ids)} template_prefix={len(template_prefix)}" + ) cut = -1 for pos in reversed(range(len(template_prefix))): - if template_ids[pos] in eos_set: + if template_ids[pos] == eos: cut = pos break - if cut < 0: - return template_ids + assert cut >= 0, "no EOS token found in chat-templated messages" return model_prefix[:cut_model] + template_ids[cut:] -def patch_serving_chat_for_nemo_gym() -> None: - _serving_chat_cls = None - for _mod in ( - "vllm.entrypoints.openai.chat_completion.serving", - "vllm.entrypoints.openai.chat_completion", - "vllm.entrypoints.openai.api_server", - "vllm.entrypoints.openai.serving_chat", - ): - try: - import importlib - m = importlib.import_module(_mod) - if hasattr(m, "OpenAIServingChat"): - _serving_chat_cls = m.OpenAIServingChat - break - except ImportError: - continue - - if _serving_chat_cls is None: - logger.warning("[nemo-gym] could not find OpenAIServingChat; skipping retokenization patch.") - return - - OpenAIServingChat = _serving_chat_cls - _original_preprocess_chat = OpenAIServingChat._preprocess_chat - - async def _patched_preprocess_chat( +def _make_patched_preprocess_chat(original): + async def _patched( self, request, messages, default_template, default_template_content_format, default_template_kwargs, tool_dicts=None, tool_parser=None, @@ -75,7 +55,7 @@ async def _patched_preprocess_chat( required_prefix = list(msg.prompt_token_ids) + list(msg.generation_token_ids) break - res = await _original_preprocess_chat( + res = await original( self, request, messages, default_template, default_template_content_format, default_template_kwargs, tool_dicts=tool_dicts, tool_parser=tool_parser, @@ -85,15 +65,69 @@ async def _patched_preprocess_chat( return res try: - tok = self.renderer.get_tokenizer() # avoid concurrent tokenizer access - else already borrowed error w/ hermes tool parser + # call _preprocess_chat on messages up to last assistant turn (no gen prompt) + # to get template_prefix_ids for _replace_prefix_tokens + last_asst = next( + (i for i in reversed(range(len(messages))) + if (messages[i].get("role") if isinstance(messages[i], dict) + else getattr(messages[i], "role", None)) == "assistant"), + None, + ) + prefix_msgs = messages[:last_asst + 1] if last_asst is not None else messages + prefix_res = await original( + self, request, prefix_msgs, + default_template, default_template_content_format, + {**(default_template_kwargs or {}), "add_generation_prompt": False}, + tool_dicts=tool_dicts, tool_parser=tool_parser, + ) + template_prefix_ids = prefix_res[1][0]["prompt_token_ids"] + + tok = self.renderer.get_tokenizer() engine_prompt = res[1][0] engine_prompt["prompt_token_ids"] = _replace_prefix_tokens( - required_prefix, required_prefix, + required_prefix, template_prefix_ids, engine_prompt["prompt_token_ids"], tok, ) except Exception as e: logger.warning(f"[nemo-gym] retokenization patch failed, skipping: {e}") return res - OpenAIServingChat._preprocess_chat = _patched_preprocess_chat - logger.info("[nemo-gym] applied retokenization patch to OpenAIServingChat.") + return _patched + + +def patch_serving_chat_for_nemo_gym() -> None: + import importlib + + targets = { + "OpenAIServingChat": ( + "vllm.entrypoints.openai.chat_completion.serving", + "vllm.entrypoints.openai.chat_completion", + "vllm.entrypoints.openai.api_server", + "vllm.entrypoints.openai.serving_chat", + ), + "OpenAIServingTokenization": ( + "vllm.entrypoints.openai.api_server", + "vllm.entrypoints.serve.tokenize.serving", + ), + } + + patched_any = False + for cls_name, mods in targets.items(): + cls = None + for mod in mods: + try: + m = importlib.import_module(mod) + if hasattr(m, cls_name): + cls = getattr(m, cls_name) + break + except ImportError: + continue + if cls is None: + logger.warning(f"[nemo-gym] could not find {cls_name}; skipping.") + continue + cls._preprocess_chat = _make_patched_preprocess_chat(cls._preprocess_chat) + logger.warning(f"[nemo-gym] applied retokenization patch to {cls_name}.") + patched_any = True + + if not patched_any: + logger.warning("[nemo-gym] retokenization patch not applied to any serving class.") From 613295b8982490533884d703fd16979688f8230b Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sun, 29 Mar 2026 18:47:43 -0700 Subject: [PATCH 14/19] docs Signed-off-by: cmunley1 --- docs/examples/nemo_gym.rst | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst index bb72da41f00..44692796591 100644 --- a/docs/examples/nemo_gym.rst +++ b/docs/examples/nemo_gym.rst @@ -8,18 +8,20 @@ running NeMo Gym environments with verl using a custom agent loop manager. Overview -------- -The integration adds two components to ``verl/experimental/nemo_gym/``: +The integration adds three components to ``verl/experimental/nemo_gym/``: -- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: offloads multi-turn rollouts - to NeMo Gym, handles retokenization correction across turns, and formats output. - The retokenization logic may change shortly to follow verl approach. -- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym datasets +- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: drives multi-turn rollouts + via NeMo Gym and formats results into verl's DataProto format. +- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym JSONL datasets including messages, tools, agent refs, and metadata into verl format. +- ``server_patch.py`` — patches vLLM's ``OpenAIServingChat`` and + ``OpenAIServingTokenization`` to fix retokenization across multi-turn calls, + matching NeMo RL's approach. Requirements ------------ -- A NeMo Gym local clone (``gym-ref``) with the environment you want to train on. TODO finalize submodule decision +- A NeMo Gym clone with the environment you want to train on. - ``pip install -e /path/to/gym-ref`` installed into the container at job start. Quick Start @@ -57,15 +59,7 @@ The ``nemo_gym`` block in ``AgentLoopConfig`` accepts: nemo_gym: nemo_gym_root: /path/to/gym-ref uses_reasoning_parser: false - initial_global_config_dict: - config_paths: - - /path/to/env.yaml + config_paths: + - /path/to/env.yaml -Tool Calling ------------- - -For environments that use tool calling (e.g. workplace assistant), use a tool parser, for example:: - - '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' - '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes' - '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' +For environments that use tool calling (e.g. workplace assistant), use a tool parser. For reasoning models, use a reasoning parser. \ No newline at end of file From a68f7a3bd9b8fc71d3a86a34e979a5c80dcfeddd Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sun, 29 Mar 2026 22:20:22 -0700 Subject: [PATCH 15/19] lint and log single env metrics too Signed-off-by: cmunley1 --- verl/experimental/nemo_gym/__init__.py | 13 ++++++ verl/experimental/nemo_gym/agent_loop.py | 19 ++++---- verl/experimental/nemo_gym/server_patch.py | 51 +++++++++++++++------- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/verl/experimental/nemo_gym/__init__.py b/verl/experimental/nemo_gym/__init__.py index e69de29bb2d..1ce90c5eb35 100644 --- a/verl/experimental/nemo_gym/__init__.py +++ b/verl/experimental/nemo_gym/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 8055f7aa10a..3565adca894 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -81,9 +81,7 @@ async def _init_nemo_gym(self) -> None: from nemo_gym.rollout_collection import RolloutCollectionHelper from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, BaseServerConfig except ModuleNotFoundError as e: - raise ImportError( - "nemo-gym not found. Install it with: pip install -e /path/to/gym-ref" - ) from e + raise ImportError("nemo-gym not found. Install it with: pip install -e /path/to/gym-ref") from e config_paths = list(nemo_gym_cfg.config_paths or []) if nemo_gym_cfg else [] initial_global_cfg = {"config_paths": config_paths} if config_paths else {} @@ -225,6 +223,7 @@ async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: if env_metrics: try: import wandb + if wandb.run is not None: wandb.log({**env_metrics, "train/global_step": global_steps}, step=global_steps) except Exception: @@ -283,9 +282,9 @@ def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: prompt_ids = item["prompt_token_ids"] - assert ( - seen_token_ids == prompt_ids[: len(seen_token_ids)] - ), f"Non-contiguous token IDs (server_patch active?). seen={len(seen_token_ids)} prompt={len(prompt_ids)}" + assert seen_token_ids == prompt_ids[: len(seen_token_ids)], ( + f"Non-contiguous token IDs (server_patch active?). seen={len(seen_token_ids)} prompt={len(prompt_ids)}" + ) message_log.append( { @@ -313,8 +312,7 @@ def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: if not message_log: raise ValueError( - "nemo-gym returned a result with no generation data. " - "The prompt may exceed vLLM's max_model_len." + "nemo-gym returned a result with no generation data. The prompt may exceed vLLM's max_model_len." ) return { @@ -480,8 +478,7 @@ def _mean(vals): env_rewards: dict = defaultdict(list) for r, s in zip(results, stats, strict=True): env_rewards[r.get("env", "unknown")].append(s["reward"]) - if len(env_rewards) > 1: - for env, rewards in env_rewards.items(): - metrics[f"env/{env}/reward_mean"] = _mean(rewards) + for env, rewards in env_rewards.items(): + metrics[f"env/{env}/reward_mean"] = _mean(rewards) return metrics diff --git a/verl/experimental/nemo_gym/server_patch.py b/verl/experimental/nemo_gym/server_patch.py index cbc0123ce3d..aa56c60029f 100644 --- a/verl/experimental/nemo_gym/server_patch.py +++ b/verl/experimental/nemo_gym/server_patch.py @@ -41,9 +41,14 @@ def _replace_prefix_tokens(model_prefix, template_prefix, template_ids, tok): def _make_patched_preprocess_chat(original): async def _patched( - self, request, messages, - default_template, default_template_content_format, default_template_kwargs, - tool_dicts=None, tool_parser=None, + self, + request, + messages, + default_template, + default_template_content_format, + default_template_kwargs, + tool_dicts=None, + tool_parser=None, ): required_prefix = getattr(request, "required_prefix_token_ids", None) if required_prefix is None: @@ -56,9 +61,14 @@ async def _patched( break res = await original( - self, request, messages, - default_template, default_template_content_format, default_template_kwargs, - tool_dicts=tool_dicts, tool_parser=tool_parser, + self, + request, + messages, + default_template, + default_template_content_format, + default_template_kwargs, + tool_dicts=tool_dicts, + tool_parser=tool_parser, ) if required_prefix is None: @@ -68,25 +78,36 @@ async def _patched( # call _preprocess_chat on messages up to last assistant turn (no gen prompt) # to get template_prefix_ids for _replace_prefix_tokens last_asst = next( - (i for i in reversed(range(len(messages))) - if (messages[i].get("role") if isinstance(messages[i], dict) - else getattr(messages[i], "role", None)) == "assistant"), + ( + i + for i in reversed(range(len(messages))) + if ( + messages[i].get("role") if isinstance(messages[i], dict) else getattr(messages[i], "role", None) + ) + == "assistant" + ), None, ) - prefix_msgs = messages[:last_asst + 1] if last_asst is not None else messages + prefix_msgs = messages[: last_asst + 1] if last_asst is not None else messages prefix_res = await original( - self, request, prefix_msgs, - default_template, default_template_content_format, + self, + request, + prefix_msgs, + default_template, + default_template_content_format, {**(default_template_kwargs or {}), "add_generation_prompt": False}, - tool_dicts=tool_dicts, tool_parser=tool_parser, + tool_dicts=tool_dicts, + tool_parser=tool_parser, ) template_prefix_ids = prefix_res[1][0]["prompt_token_ids"] tok = self.renderer.get_tokenizer() engine_prompt = res[1][0] engine_prompt["prompt_token_ids"] = _replace_prefix_tokens( - required_prefix, template_prefix_ids, - engine_prompt["prompt_token_ids"], tok, + required_prefix, + template_prefix_ids, + engine_prompt["prompt_token_ids"], + tok, ) except Exception as e: logger.warning(f"[nemo-gym] retokenization patch failed, skipping: {e}") From 7fe549843eb861da075300ec98e96c9706847753 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 31 Mar 2026 11:47:07 -0700 Subject: [PATCH 16/19] updates Signed-off-by: cmunley1 --- verl/experimental/nemo_gym/agent_loop.py | 16 ++++++++++++---- verl/experimental/nemo_gym/server_patch.py | 15 +++++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py index 3565adca894..27df3087670 100644 --- a/verl/experimental/nemo_gym/agent_loop.py +++ b/verl/experimental/nemo_gym/agent_loop.py @@ -35,6 +35,9 @@ from verl.utils.model import compute_position_id_with_mask from verl.utils.ray_utils import auto_await +# _postprocess is an AgentLoopWorker method that stacks _InternalAgentLoopOutput +# into a DataProto batch. It accesses self.distillation_enabled and +# self.reward_loop_worker_handles, both of which we set in _init_nemo_gym. _postprocess = _AgentLoopWorker._postprocess @@ -137,6 +140,7 @@ async def _init_nemo_gym(self) -> None: from verl.utils.config import omega_conf_to_dataclass self._tokenizer = omega_conf_to_dataclass(self.model_config).tokenizer + self.distillation_enabled = False self._rollout_loop = asyncio.new_event_loop() self._rollout_thread = threading.Thread( @@ -179,8 +183,11 @@ async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: raw_results.append(result) results = [None] * len(nemo_gym_examples) - for rowidx, result in zip(rowidxs, raw_results, strict=False): + for rowidx, result in zip(rowidxs, raw_results, strict=True): results[rowidx] = result + missing = [i for i, r in enumerate(results) if r is None] + if missing: + raise RuntimeError(f"nemo-gym did not return results for samples: {missing}") prompt_lens = [sum(len(m["token_ids"]) for m in r["input_message_log"]) for r in results] response_lens = [ @@ -282,9 +289,10 @@ def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: prompt_ids = item["prompt_token_ids"] - assert seen_token_ids == prompt_ids[: len(seen_token_ids)], ( - f"Non-contiguous token IDs (server_patch active?). seen={len(seen_token_ids)} prompt={len(prompt_ids)}" - ) + if seen_token_ids != prompt_ids[: len(seen_token_ids)]: + raise ValueError( + f"Non-contiguous token IDs (server_patch active?). seen={len(seen_token_ids)} prompt={len(prompt_ids)}" + ) message_log.append( { diff --git a/verl/experimental/nemo_gym/server_patch.py b/verl/experimental/nemo_gym/server_patch.py index aa56c60029f..355c80d3e28 100644 --- a/verl/experimental/nemo_gym/server_patch.py +++ b/verl/experimental/nemo_gym/server_patch.py @@ -22,20 +22,23 @@ def _replace_prefix_tokens(model_prefix, template_prefix, template_ids, tok): if not model_prefix: return template_ids eos = tok.eos_token_id - assert eos is not None, "tokenizer must have eos_token_id" + if eos is None: + raise ValueError("tokenizer must have eos_token_id") cut_model = len(model_prefix) if model_prefix[-1] == eos: cut_model -= 1 - assert len(template_ids) > len(template_prefix), ( - f"non-monotonically increasing trajectory: " - f"template_ids={len(template_ids)} template_prefix={len(template_prefix)}" - ) + if len(template_ids) <= len(template_prefix): + raise ValueError( + f"non-monotonically increasing trajectory: " + f"template_ids={len(template_ids)} template_prefix={len(template_prefix)}" + ) cut = -1 for pos in reversed(range(len(template_prefix))): if template_ids[pos] == eos: cut = pos break - assert cut >= 0, "no EOS token found in chat-templated messages" + if cut < 0: + raise ValueError("no EOS token found in chat-templated messages") return model_prefix[:cut_model] + template_ids[cut:] From 37f83c4a54fbf680b1d127c452d6ee7c92990b14 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 4 Apr 2026 02:04:36 -0700 Subject: [PATCH 17/19] move to recipe Signed-off-by: cmunley1 --- .gitmodules | 3 - 3rdparty/nemo_gym | 1 - docs/examples/nemo_gym.rst | 65 --- docs/index.rst | 1 - submit_math.sh | 171 ------ submit_workplace.sh | 174 ------- verl/experimental/nemo_gym/__init__.py | 13 - verl/experimental/nemo_gym/agent_loop.py | 492 ------------------ verl/experimental/nemo_gym/config.env.example | 7 - verl/experimental/nemo_gym/dataset.py | 84 --- verl/experimental/nemo_gym/server_patch.py | 157 ------ verl/workers/config/rollout.py | 8 - 12 files changed, 1176 deletions(-) delete mode 160000 3rdparty/nemo_gym delete mode 100644 docs/examples/nemo_gym.rst delete mode 100755 submit_math.sh delete mode 100644 submit_workplace.sh delete mode 100644 verl/experimental/nemo_gym/__init__.py delete mode 100644 verl/experimental/nemo_gym/agent_loop.py delete mode 100644 verl/experimental/nemo_gym/config.env.example delete mode 100644 verl/experimental/nemo_gym/dataset.py delete mode 100644 verl/experimental/nemo_gym/server_patch.py diff --git a/.gitmodules b/.gitmodules index de9f6e53919..d5dd7a6aa57 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ [submodule "recipe"] path = recipe url = https://github.com/verl-project/verl-recipe.git -[submodule "3rdparty/nemo_gym"] - path = 3rdparty/nemo_gym - url = https://github.com/NVIDIA-NeMo/Gym diff --git a/3rdparty/nemo_gym b/3rdparty/nemo_gym deleted file mode 160000 index f1399809b86..00000000000 --- a/3rdparty/nemo_gym +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f1399809b86c8d3f669bf3cf5c24efabec62dfad diff --git a/docs/examples/nemo_gym.rst b/docs/examples/nemo_gym.rst deleted file mode 100644 index 44692796591..00000000000 --- a/docs/examples/nemo_gym.rst +++ /dev/null @@ -1,65 +0,0 @@ -NVIDIA NeMo Gym Integration -================================== - -`NVIDIA NeMo Gym `_ (`docs `_) -is an RL environment framework for scalable, multi-environment, agentic RL. This integration enables -running NeMo Gym environments with verl using a custom agent loop manager. - -Overview --------- - -The integration adds three components to ``verl/experimental/nemo_gym/``: - -- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: drives multi-turn rollouts - via NeMo Gym and formats results into verl's DataProto format. -- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym JSONL datasets - including messages, tools, agent refs, and metadata into verl format. -- ``server_patch.py`` — patches vLLM's ``OpenAIServingChat`` and - ``OpenAIServingTokenization`` to fix retokenization across multi-turn calls, - matching NeMo RL's approach. - -Requirements ------------- - -- A NeMo Gym clone with the environment you want to train on. -- ``pip install -e /path/to/gym-ref`` installed into the container at job start. - -Quick Start ------------ - -1. **Install NeMo Gym** in your container startup script:: - - pip install -e /path/to/gym-ref - -2. **Prepare training datasets** in NeMo Gym JSONL format. Each line should be a - JSON object with a ``responses_create_params`` field containing the initial - messages and any tools, plus an ``agent_ref`` pointing at your environment's - agent server. - -3. **Add these overrides** to your verl training command:: - - +data.custom_cls.path=verl/experimental/nemo_gym/dataset.py - +data.custom_cls.name=NemoGymJSONLDataset - +actor_rollout_ref.rollout.agent.agent_loop_manager_class=verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager - "+actor_rollout_ref.rollout.agent.nemo_gym.config_paths=[/path/to/env.yaml]" - +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root=/path/to/gym-ref - -See ``submit_workplace.sh`` and ``submit_math.sh`` for working examples. - -Configuration -------------- - -The ``nemo_gym`` block in ``AgentLoopConfig`` accepts: - -.. code-block:: yaml - - actor_rollout_ref: - rollout: - agent: - nemo_gym: - nemo_gym_root: /path/to/gym-ref - uses_reasoning_parser: false - config_paths: - - /path/to/env.yaml - -For environments that use tool calling (e.g. workplace assistant), use a tool parser. For reasoning models, use a reasoning parser. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 2bac7398448..381d3a6bad9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,7 +63,6 @@ verl is fast with: examples/gsm8k_example examples/multi_modal_example examples/skypilot_examples - examples/nemo_gym .. toctree:: :maxdepth: 1 diff --git a/submit_math.sh b/submit_math.sh deleted file mode 100755 index 874f3ef33ab..00000000000 --- a/submit_math.sh +++ /dev/null @@ -1,171 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=verl-nemogym-dapo-7b-math -#SBATCH --nodes=4 -#SBATCH --ntasks-per-node=1 -#SBATCH --partition=your_partition -#SBATCH --account=your_account -#SBATCH --time=4:00:00 -#SBATCH --gres=gpu:8 -#SBATCH --exclusive -#SBATCH --output=logs/slurm-%j.out -#SBATCH --error=logs/slurm-%j.err - -set -euo pipefail - -GPUS_PER_NODE=8 - -source "${SLURM_SUBMIT_DIR}/config.env" - -MODEL_PATH="${DATA_ROOT}/models/Qwen2.5-Math-7B" -TRAIN_FILE="${DATA_ROOT}/math_with_judge/dapo17k_bytedtsinghua_train_nrl.jsonl" -TEST_FILE="${DATA_ROOT}/math_with_judge/aime24_bytedtsinghua_validation_nrl.jsonl" -CKPTS_DIR="${RESULTS_ROOT}/DAPO-Qwen2.5-7b-MATH-megatron" - -CONTAINER="verlai/verl:vllm017.latest" -MOUNTS="/lustre:/lustre" - -mkdir -p "${CKPTS_DIR}" - -nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") -nodes_array=($nodes) -head_node=${nodes_array[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}') - -RAY_PORT=6379 -ip_head="${head_node_ip}:${RAY_PORT}" -echo "Head node: ${head_node} (${head_node_ip})" - -SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}" - -echo "Starting Ray head on ${head_node}..." -srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \ - env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start --head \ - --node-ip-address="${head_node_ip}" \ - --port=${RAY_PORT} \ - --num-gpus="${GPUS_PER_NODE}" \ - --block & -sleep 10 - -worker_num=$((SLURM_JOB_NUM_NODES - 1)) -for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} - echo "Starting Ray worker ${i} on ${node_i}..." - srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \ - env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start \ - --address="${ip_head}" \ - --num-gpus="${GPUS_PER_NODE}" \ - --block & - sleep 5 -done - -CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head" -echo "Waiting for ray-head container at ${CONTAINER_DIR}..." -elapsed=0 -while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do - sleep 5 - elapsed=$((elapsed + 5)) -done -if [[ ! -d "${CONTAINER_DIR}" ]]; then - echo "ERROR: ray-head container never appeared after 300s" - exit 1 -fi -echo "Container ready. Waiting 90s for all Ray workers to connect..." -sleep 90 - -echo "Installing nemo-gym..." -srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ - --no-container-mount-home --container-mounts=${MOUNTS} \ - --container-name=ray-head \ - bash -c "touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" - -echo "Launching training on ${head_node}..." -PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ - --no-container-mount-home --container-mounts=${MOUNTS} \ - --container-workdir=${VERL_ROOT} --container-name=ray-head \ - env -u ROCR_VISIBLE_DEVICES \ - WANDB_API_KEY="${WANDB_API_KEY}" \ - HF_HOME="${HF_HOME}" \ - HF_HUB_CACHE="${HF_HOME}/hub" \ - RAY_ADDRESS="auto" \ - VLLM_USE_V1=1 \ - TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ - PYTHONPATH="${NEMO_GYM_ROOT}" \ - RAY_grpc_keepalive_time_ms=60000 \ - RAY_grpc_keepalive_timeout_ms=600000 \ - RAY_grpc_client_keepalive_time_ms=60000 \ - RAY_grpc_client_keepalive_timeout_ms=600000 \ - python3 -m verl.trainer.main_ppo \ - --config-path="${VERL_ROOT}/recipe/dapo/config" \ - --config-name=dapo_megatron_trainer.yaml \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym/dataset.py" \ - +data.custom_cls.name=NemoGymJSONLDataset \ - data.truncation=left \ - data.max_prompt_length=2048 \ - data.max_response_length=8192 \ - data.train_batch_size=512 \ - actor_rollout_ref.rollout.n=16 \ - algorithm.adv_estimator=grpo \ - algorithm.use_kl_in_reward=False \ - algorithm.kl_ctrl.kl_coef=0.0 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.kl_loss_coef=0.0 \ - actor_rollout_ref.actor.clip_ratio_low=0.2 \ - actor_rollout_ref.actor.clip_ratio_high=0.28 \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.megatron.param_offload=True \ - actor_rollout_ref.actor.megatron.optimizer_offload=True \ - actor_rollout_ref.actor.megatron.grad_offload=True \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=token-mean \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ - actor_rollout_ref.rollout.temperature=1.0 \ - actor_rollout_ref.rollout.top_p=1.0 \ - actor_rollout_ref.rollout.top_k=-1 \ - actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ - actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ - actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.ref.megatron.param_offload=True \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=True \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=4096 \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=1.0 \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=8192 \ - 'trainer.logger=["console","wandb"]' \ - trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ - trainer.experiment_name=dapo-7b-nemogym \ - trainer.n_gpus_per_node=${GPUS_PER_NODE} \ - trainer.nnodes=${SLURM_JOB_NUM_NODES} \ - trainer.val_before_train=False \ - trainer.test_freq=10 \ - trainer.save_freq=10 \ - trainer.total_epochs=10 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 \ - +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ - "+actor_rollout_ref.rollout.agent.nemo_gym.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/math_with_judge/configs/math_with_judge.yaml]" \ - +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ - +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ - 2>&1 diff --git a/submit_workplace.sh b/submit_workplace.sh deleted file mode 100644 index bcc3ca05df8..00000000000 --- a/submit_workplace.sh +++ /dev/null @@ -1,174 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=verl-nemogym-dapo-4b-workplace -#SBATCH --nodes=4 -#SBATCH --ntasks-per-node=1 -#SBATCH --partition=your_partition -#SBATCH --account=your_account -#SBATCH --time=4:00:00 -#SBATCH --gres=gpu:8 -#SBATCH --exclusive -#SBATCH --output=logs/slurm-%j.out -#SBATCH --error=logs/slurm-%j.err - -set -euo pipefail - -GPUS_PER_NODE=8 - -source "${SLURM_SUBMIT_DIR}/config.env" - -MODEL_PATH="${HF_HOME}/hub/models--Qwen--Qwen3-4B-Instruct-2507/snapshots/cdbee75f17c01a7cc42f958dc650907174af0554" -TRAIN_FILE="${DATA_ROOT}/workplace_assistant/train.jsonl" -TEST_FILE="${DATA_ROOT}/workplace_assistant/validation.jsonl" -CKPTS_DIR="${RESULTS_ROOT}/dapo-qwen3-4b-workplace" - -CONTAINER="verlai/verl:vllm017.latest" -MOUNTS="/lustre:/lustre" - -mkdir -p "${CKPTS_DIR}" - -nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") -nodes_array=($nodes) -head_node=${nodes_array[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}') - -RAY_PORT=6379 -ip_head="${head_node_ip}:${RAY_PORT}" -echo "Head node: ${head_node} (${head_node_ip})" - -SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}" - -echo "Starting Ray head on ${head_node}..." -srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \ - env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start --head \ - --node-ip-address="${head_node_ip}" \ - --port=${RAY_PORT} \ - --num-gpus="${GPUS_PER_NODE}" \ - --block & -sleep 10 - -worker_num=$((SLURM_JOB_NUM_NODES - 1)) -for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} - echo "Starting Ray worker ${i} on ${node_i}..." - srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \ - env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start \ - --address="${ip_head}" \ - --num-gpus="${GPUS_PER_NODE}" \ - --block & - sleep 5 -done - -CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head" -echo "Waiting for ray-head container at ${CONTAINER_DIR}..." -elapsed=0 -while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do - sleep 5 - elapsed=$((elapsed + 5)) -done -if [[ ! -d "${CONTAINER_DIR}" ]]; then - echo "ERROR: ray-head container never appeared after 300s" - exit 1 -fi -echo "Container ready. Waiting 90s for all Ray workers to connect..." -sleep 90 - -echo "Installing nemo-gym..." -srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ - --no-container-mount-home --container-mounts=${MOUNTS} \ - --container-name=ray-head \ - bash -c "PYTHONPATH= touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" - -echo "Launching training on ${head_node}..." -PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ - --no-container-mount-home --container-mounts=${MOUNTS} \ - --container-workdir=${VERL_ROOT} --container-name=ray-head \ - env -u ROCR_VISIBLE_DEVICES \ - WANDB_API_KEY="${WANDB_API_KEY}" \ - HF_HOME="${HF_HOME}" \ - HF_HUB_CACHE="${HF_HOME}/hub" \ - RAY_ADDRESS="auto" \ - VLLM_USE_V1=1 \ - TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ - PYTHONPATH="${NEMO_GYM_ROOT}" \ - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ - RAY_grpc_keepalive_time_ms=60000 \ - RAY_grpc_keepalive_timeout_ms=600000 \ - RAY_grpc_client_keepalive_time_ms=60000 \ - RAY_grpc_client_keepalive_timeout_ms=600000 \ - python3 -m verl.trainer.main_ppo \ - --config-path="${VERL_ROOT}/recipe/dapo/config" \ - --config-name=dapo_megatron_trainer.yaml \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - +data.custom_cls.path="${VERL_ROOT}/verl/experimental/nemo_gym/dataset.py" \ - +data.custom_cls.name=NemoGymJSONLDataset \ - data.truncation=left \ - data.train_batch_size=512 \ - actor_rollout_ref.rollout.n=16 \ - algorithm.adv_estimator=grpo \ - algorithm.use_kl_in_reward=False \ - algorithm.kl_ctrl.kl_coef=0.0 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.kl_loss_coef=0.0 \ - actor_rollout_ref.actor.clip_ratio_low=0.2 \ - actor_rollout_ref.actor.clip_ratio_high=0.28 \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.megatron.param_offload=True \ - actor_rollout_ref.actor.megatron.optimizer_offload=True \ - actor_rollout_ref.actor.megatron.grad_offload=True \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=token-mean \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ - actor_rollout_ref.rollout.temperature=1.0 \ - actor_rollout_ref.rollout.top_p=1.0 \ - actor_rollout_ref.rollout.top_k=-1 \ - actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ - actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ - actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.name=vllm \ - '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' \ - '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes_patched' \ - "+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-parser-plugin=${VERL_ROOT}/verl/experimental/tool_parsers/hermes_tool_parser_patched.py" \ - '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.ref.megatron.param_offload=True \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=True \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=4096 \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=1.0 \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=8192 \ - 'trainer.logger=["console","wandb"]' \ - trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ - trainer.experiment_name=dapo-qwen3-4b-workplace \ - trainer.n_gpus_per_node=${GPUS_PER_NODE} \ - trainer.nnodes=${SLURM_JOB_NUM_NODES} \ - trainer.val_before_train=False \ - trainer.test_freq=10 \ - trainer.save_freq=10 \ - trainer.total_epochs=10 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 \ - +actor_rollout_ref.rollout.agent.agent_loop_manager_class='verl.experimental.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ - "+actor_rollout_ref.rollout.agent.nemo_gym.config_paths=[${NEMO_GYM_ROOT}/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,${NEMO_GYM_ROOT}/resources_servers/workplace_assistant/configs/workplace_assistant.yaml]" \ - +actor_rollout_ref.rollout.agent.nemo_gym.uses_reasoning_parser=False \ - +actor_rollout_ref.rollout.agent.nemo_gym.nemo_gym_root="${NEMO_GYM_ROOT}" \ - 2>&1 diff --git a/verl/experimental/nemo_gym/__init__.py b/verl/experimental/nemo_gym/__init__.py deleted file mode 100644 index 1ce90c5eb35..00000000000 --- a/verl/experimental/nemo_gym/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/nemo_gym/agent_loop.py b/verl/experimental/nemo_gym/agent_loop.py deleted file mode 100644 index 27df3087670..00000000000 --- a/verl/experimental/nemo_gym/agent_loop.py +++ /dev/null @@ -1,492 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import asyncio -import socket -import sys -import threading -from collections import defaultdict -from typing import Optional - -import ray -import torch - -from verl.experimental.agent_loop.agent_loop import ( - AgentLoopManager, - AgentLoopMetrics, - _InternalAgentLoopOutput, -) -from verl.experimental.agent_loop.agent_loop import ( - AgentLoopWorker as _AgentLoopWorker, -) -from verl.protocol import DataProto -from verl.utils.model import compute_position_id_with_mask -from verl.utils.ray_utils import auto_await - -# _postprocess is an AgentLoopWorker method that stacks _InternalAgentLoopOutput -# into a DataProto batch. It accesses self.distillation_enabled and -# self.reward_loop_worker_handles, both of which we set in _init_nemo_gym. -_postprocess = _AgentLoopWorker._postprocess - - -class NemoGymAgentLoopManager(AgentLoopManager): - @classmethod - @auto_await - async def create( - cls, - config, - worker_group=None, - rollout_resource_pool=None, - reward_loop_worker_handles=None, - teacher_model_manager=None, - ) -> NemoGymAgentLoopManager: - instance = cls( - config, - worker_group, - rollout_resource_pool, - teacher_model_manager, - reward_loop_worker_handles, - ) - await instance._initialize_llm_servers() - await instance._init_global_load_balancer() - await instance._apply_server_patch() - await instance._init_nemo_gym() - return instance - - async def _apply_server_patch(self) -> None: - futs = [handle.apply_nemo_gym_server_patch.remote() for handle in self.server_handles] - await asyncio.get_event_loop().run_in_executor(None, ray.get, futs) - - async def _init_nemo_gym(self) -> None: - nemo_gym_cfg = self.rollout_config.agent.nemo_gym - - nemo_gym_root = nemo_gym_cfg.nemo_gym_root if nemo_gym_cfg else None - - if nemo_gym_root and str(nemo_gym_root) not in sys.path: - sys.path.insert(0, str(nemo_gym_root)) - - from omegaconf import DictConfig - - try: - from nemo_gym.cli import GlobalConfigDictParserConfig, RunHelper - from nemo_gym.rollout_collection import RolloutCollectionHelper - from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, BaseServerConfig - except ModuleNotFoundError as e: - raise ImportError("nemo-gym not found. Install it with: pip install -e /path/to/gym-ref") from e - - config_paths = list(nemo_gym_cfg.config_paths or []) if nemo_gym_cfg else [] - initial_global_cfg = {"config_paths": config_paths} if config_paths else {} - - uses_reasoning_parser = nemo_gym_cfg.uses_reasoning_parser if nemo_gym_cfg else False - vllm_model_cfg = ( - initial_global_cfg.setdefault("policy_model", {}) - .setdefault("responses_api_models", {}) - .setdefault("vllm_model", {}) - ) - vllm_model_cfg["uses_reasoning_parser"] = uses_reasoning_parser - - if not uses_reasoning_parser: - vllm_model_cfg.setdefault("extra_body", {}).setdefault("chat_template_kwargs", {})["enable_thinking"] = ( - False - ) - - base_urls = [ - (addr if addr.startswith("http") else f"http://{addr}").rstrip("/") + "/v1" - for addr in self.server_addresses - ] - initial_global_cfg["policy_model_name"] = self.model_config.get("path", "") - initial_global_cfg["policy_api_key"] = "dummy_key" - initial_global_cfg["policy_base_url"] = base_urls - initial_global_cfg.setdefault("global_aiohttp_connector_limit_per_host", 16_384) - initial_global_cfg.setdefault("global_aiohttp_connector_limit", 65_536) - - ray_context = ray.get_runtime_context() - initial_global_cfg["ray_head_node_address"] = ray_context.gcs_address - - if nemo_gym_root: - initial_global_cfg.setdefault("uv_venv_dir", str(nemo_gym_root)) - initial_global_cfg.setdefault("skip_venv_if_present", True) - - node_ip = socket.gethostbyname(socket.gethostname()) - with socket.socket() as s: - s.bind(("", 0)) - head_port = s.getsockname()[1] - initial_global_cfg[HEAD_SERVER_KEY_NAME] = {"host": "0.0.0.0", "port": head_port} - self._head_server_config = BaseServerConfig(host=node_ip, port=head_port) - - self._rh = RunHelper() - self._rh.start( - global_config_dict_parser_config=GlobalConfigDictParserConfig( - dotenv_path=None, - initial_global_config_dict=DictConfig(initial_global_cfg), - skip_load_from_cli=True, - ) - ) - - self._rch = RolloutCollectionHelper() - - from verl.utils.config import omega_conf_to_dataclass - - self._tokenizer = omega_conf_to_dataclass(self.model_config).tokenizer - self.distillation_enabled = False - - self._rollout_loop = asyncio.new_event_loop() - self._rollout_thread = threading.Thread( - target=self._rollout_loop.run_forever, - daemon=True, - name="nemo-gym-rollout-loop", - ) - self._rollout_thread.start() - - print(f"NemoGymAgentLoopManager ready: {len(base_urls)} vLLM endpoints: {base_urls}") - - def generate_sequences(self, prompts: DataProto) -> DataProto: - future = asyncio.run_coroutine_threadsafe(self._async_generate_sequences(prompts), self._rollout_loop) - return future.result() - - async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: - validate = prompts.meta_info.get("validate", False) - global_steps = prompts.meta_info.get("global_steps", -1) - - nemo_gym_examples = _build_nemo_gym_examples( - prompts, - self.rollout_config, - validate=validate, - ) - - nemo_gym_result_iterator = self._rch.run_examples( - examples=nemo_gym_examples, - head_server_config=self._head_server_config, - ) - - rowidxs, raw_results = [], [] - for task in nemo_gym_result_iterator: - nemo_gym_row, nemo_gym_result = await task - try: - result = _postprocess_nemo_gym_result(nemo_gym_result, self._tokenizer) - except ValueError: - result = _empty_result(nemo_gym_row, self._tokenizer) - result["env"] = nemo_gym_row["agent_ref"]["name"] - rowidxs.append(nemo_gym_row["_rowidx"]) - raw_results.append(result) - - results = [None] * len(nemo_gym_examples) - for rowidx, result in zip(rowidxs, raw_results, strict=True): - results[rowidx] = result - missing = [i for i, r in enumerate(results) if r is None] - if missing: - raise RuntimeError(f"nemo-gym did not return results for samples: {missing}") - - prompt_lens = [sum(len(m["token_ids"]) for m in r["input_message_log"]) for r in results] - response_lens = [ - sum(len(m["token_ids"]) for m in r["message_log"][len(r["input_message_log"]) :]) for r in results - ] - prompt_length = max(prompt_lens) if prompt_lens else self.rollout_config.prompt_length - _vllm_kwargs = (getattr(self.rollout_config, "engine_kwargs", {}) or {}).get("vllm", {}) or {} - max_model_len = ( - getattr(self.rollout_config, "max_model_len", None) - or _vllm_kwargs.get("max-model-len") - or _vllm_kwargs.get("max_model_len") - ) - response_budget = (int(max_model_len) - prompt_length) if max_model_len else None - response_length = max(response_lens) if response_lens else self.rollout_config.response_length - if response_budget: - response_length = min(response_length, response_budget) - - internal_outputs = [ - _nemo_gym_result_to_verl( - result=result, - tokenizer=self._tokenizer, - prompt_length=prompt_length, - response_length=response_length, - ) - for result in results - ] - - output = _postprocess( - self, - internal_outputs, - input_non_tensor_batch=prompts.non_tensor_batch, - validate=validate, - ) - output.meta_info["global_steps"] = global_steps - rollout_metrics = _compute_rollout_metrics(results, getattr(self.rollout_config, "max_model_len", None)) - output.meta_info["timing"] = {} - output.meta_info["rollout_metrics"] = rollout_metrics - - env_metrics = {k: v for k, v in rollout_metrics.items() if k.startswith("env/")} - if env_metrics: - try: - import wandb - - if wandb.run is not None: - wandb.log({**env_metrics, "train/global_step": global_steps}, step=global_steps) - except Exception: - pass - - return output - - -def _build_nemo_gym_examples( - prompts: DataProto, - rollout_config, - validate: bool = False, -) -> list[dict]: - cfg = rollout_config - temperature = cfg.val_kwargs.temperature if validate else cfg.temperature - top_p = cfg.val_kwargs.top_p if validate else cfg.top_p - - non_tensor = prompts.non_tensor_batch - examples = [] - for i in range(len(prompts)): - messages = list(non_tensor["raw_prompt"][i]) - - if "agent_ref" not in non_tensor: - raise ValueError(f"dataset row {i} is missing agent_ref") - agent_ref = non_tensor["agent_ref"][i] - - rcp = {} - if "extra_env_info" in non_tensor and "_rcp_extra" in (non_tensor["extra_env_info"][i] or {}): - rcp.update(non_tensor["extra_env_info"][i]["_rcp_extra"]) - rcp.update({"input": messages, "temperature": temperature, "top_p": top_p}) - - row = { - "responses_create_params": rcp, - "agent_ref": agent_ref, - "_rowidx": i, - } - - if "extra_env_info" in non_tensor: - env_info = non_tensor["extra_env_info"][i] - if isinstance(env_info, dict): - for k, v in env_info.items(): - if k not in row: - row[k] = v - - examples.append(row) - return examples - - -def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: - message_log = [] - seen_token_ids: list[int] = [] - - for item in nemo_gym_result["response"]["output"]: - if "generation_token_ids" not in item: - continue - - prompt_ids = item["prompt_token_ids"] - - if seen_token_ids != prompt_ids[: len(seen_token_ids)]: - raise ValueError( - f"Non-contiguous token IDs (server_patch active?). seen={len(seen_token_ids)} prompt={len(prompt_ids)}" - ) - - message_log.append( - { - "role": "user", - "content": "", - "token_ids": torch.tensor(prompt_ids[len(seen_token_ids) :]), - } - ) - message_log.append( - { - "role": "assistant", - "content": "", - "token_ids": torch.tensor(item["generation_token_ids"]), - "generation_logprobs": torch.tensor(item["generation_log_probs"]), - } - ) - - seen_token_ids.extend(message_log[-2]["token_ids"].tolist()) - seen_token_ids.extend(message_log[-1]["token_ids"].tolist()) - - item.pop("prompt_token_ids", None) - item["prompt_str"] = tokenizer.decode(prompt_ids) - item["generation_str"] = tokenizer.decode(item.pop("generation_token_ids")) - item.pop("generation_log_probs") - - if not message_log: - raise ValueError( - "nemo-gym returned a result with no generation data. The prompt may exceed vLLM's max_model_len." - ) - - return { - "message_log": message_log, - "input_message_log": message_log[:1], - "full_result": nemo_gym_result, - } - - -def _empty_result(nemo_gym_row: dict, tokenizer) -> dict: - """empty result for overlong samples - TODO: should we truncate or something else? what is best practice here?""" - - messages = nemo_gym_row.get("responses_create_params", {}).get("input", []) - raw_prompt = [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages] - prompt_ids = tokenizer.apply_chat_template(raw_prompt, tokenize=True, add_generation_prompt=False)[-1:] - dummy_tok = torch.tensor(prompt_ids, dtype=torch.long) - return { - "message_log": [ - {"role": "user", "token_ids": dummy_tok}, - {"role": "assistant", "token_ids": dummy_tok, "generation_logprobs": torch.zeros(len(dummy_tok))}, - ], - "input_message_log": [{"role": "user", "token_ids": dummy_tok}], - "full_result": {"reward": 0.0}, - } - - -def _nemo_gym_result_to_verl( - result: dict, - tokenizer, - prompt_length: int, - response_length: int, -) -> _InternalAgentLoopOutput: - """Pack message_log into padded tensors for verl and mask non assistant messages""" - message_log = result["message_log"] - input_message_log = result["input_message_log"] - - prompt_ids_raw: list[int] = [] - for msg in input_message_log: - tids = msg["token_ids"] - prompt_ids_raw.extend(tids.tolist() if isinstance(tids, torch.Tensor) else tids) - - n_prompt_msgs = len(input_message_log) - response_ids_raw: list[int] = [] - response_mask_raw: list[int] = [] - response_logprobs_raw: list[float] = [] - - for msg in message_log[n_prompt_msgs:]: - tids = msg["token_ids"] - toks = tids.tolist() if isinstance(tids, torch.Tensor) else list(tids) - if msg["role"] == "assistant": - response_ids_raw.extend(toks) - response_mask_raw.extend([1] * len(toks)) - lp = msg.get("generation_logprobs") - response_logprobs_raw.extend( - lp.tolist() if isinstance(lp, torch.Tensor) else list(lp) if lp is not None else [0.0] * len(toks) - ) - else: - response_ids_raw.extend(toks) - response_mask_raw.extend([0] * len(toks)) - response_logprobs_raw.extend([0.0] * len(toks)) - - response_ids_raw = response_ids_raw[:response_length] - response_mask_raw = response_mask_raw[:response_length] - response_logprobs_raw = response_logprobs_raw[:response_length] - - tokenizer.padding_side = "left" - prompt_out = tokenizer.pad( - {"input_ids": prompt_ids_raw}, - padding="max_length", - max_length=prompt_length, - return_tensors="pt", - return_attention_mask=True, - ) - prompt_ids = prompt_out["input_ids"] - if prompt_ids.dim() == 1: - prompt_ids = prompt_ids.unsqueeze(0) - prompt_out["attention_mask"] = prompt_out["attention_mask"].unsqueeze(0) - - tokenizer.padding_side = "right" - response_out = tokenizer.pad( - {"input_ids": response_ids_raw}, - padding="max_length", - max_length=response_length, - return_tensors="pt", - return_attention_mask=True, - ) - response_ids = response_out["input_ids"] - response_attn = response_out["attention_mask"] - if response_ids.dim() == 1: - response_ids = response_ids.unsqueeze(0) - response_attn = response_attn.unsqueeze(0) - - pad = response_length - len(response_mask_raw) - response_mask = torch.tensor(response_mask_raw + [0] * pad, dtype=torch.long).unsqueeze(0) * response_attn - - pad = response_length - len(response_logprobs_raw) - response_logprobs = torch.tensor(response_logprobs_raw + [0.0] * pad, dtype=torch.float32).unsqueeze(0) - - attention_mask = torch.cat([prompt_out["attention_mask"], response_attn], dim=1) - input_ids = torch.cat([prompt_ids, response_ids], dim=1) - - position_ids = compute_position_id_with_mask(attention_mask) - - reward_score: Optional[float] = result.get("full_result", {}).get("reward", None) - num_turns = sum(1 for m in message_log if m["role"] == "assistant") - - return _InternalAgentLoopOutput( - prompt_ids=prompt_ids, - response_ids=response_ids, - input_ids=input_ids, - position_ids=position_ids, - response_mask=response_mask, - attention_mask=attention_mask, - response_logprobs=response_logprobs, - routed_experts=None, - multi_modal_inputs={}, - teacher_logprobs=None, - teacher_ids=None, - reward_score=reward_score, - num_turns=num_turns, - metrics=AgentLoopMetrics(), - extra_fields={}, - ) - - -def _compute_rollout_metrics(results: list[dict], max_model_len: Optional[int] = None) -> dict: - batch_size = len(results) - if batch_size == 0: - return {} - - def _mean(vals): - return sum(vals) / len(vals) if vals else 0.0 - - stats = [] - for r in results: - ml = r["message_log"] - total = sum(len(m["token_ids"]) for m in ml) - asst = sum(len(m["token_ids"]) for m in ml if m["role"] == "assistant") - turns = sum(1 for m in ml if m["role"] == "user") - hit_max = (max_model_len is not None) and (total >= max_model_len) - stats.append( - { - "reward": float(r.get("full_result", {}).get("reward", 0.0)), - "asst": asst, - "total": total, - "turns": turns, - "hit_max": hit_max, - } - ) - - metrics = { - "turns_per_sample/mean": _mean([s["turns"] for s in stats]), - "total_tokens_per_sample/mean": _mean([s["total"] for s in stats]), - "gen_tokens_per_sample/mean": _mean([s["asst"] for s in stats]), - "mean_gen_tokens_per_sample": _mean([s["asst"] for s in stats]), - "total_reward/mean": _mean([s["reward"] for s in stats]), - "natural_termination_rate": sum(not s["hit_max"] for s in stats) / batch_size, - "truncation_rate": sum(s["hit_max"] for s in stats) / batch_size, - } - - # per environment metrics - env_rewards: dict = defaultdict(list) - for r, s in zip(results, stats, strict=True): - env_rewards[r.get("env", "unknown")].append(s["reward"]) - for env, rewards in env_rewards.items(): - metrics[f"env/{env}/reward_mean"] = _mean(rewards) - - return metrics diff --git a/verl/experimental/nemo_gym/config.env.example b/verl/experimental/nemo_gym/config.env.example deleted file mode 100644 index 6798fd72409..00000000000 --- a/verl/experimental/nemo_gym/config.env.example +++ /dev/null @@ -1,7 +0,0 @@ -VERL_ROOT=/path/to/verl -NEMO_GYM_ROOT=/path/to/gym-ref -HF_HOME=/path/to/hf_home -RESULTS_ROOT=/path/to/results -DATA_ROOT=/path/to/data -WANDB_USERNAME=your_wandb_username -WANDB_API_KEY=your_key_here diff --git a/verl/experimental/nemo_gym/dataset.py b/verl/experimental/nemo_gym/dataset.py deleted file mode 100644 index ab29089f170..00000000000 --- a/verl/experimental/nemo_gym/dataset.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from pathlib import Path - -import torch -from torch.utils.data import Dataset - - -class NemoGymJSONLDataset(Dataset): - def __init__( - self, - data_files: str | list[str], - tokenizer, - processor=None, - config=None, - **kwargs, - ): - if isinstance(data_files, str): - data_files = [data_files] - - self.tokenizer = tokenizer - self._rows: list[dict] = [] - - for path in data_files: - path = Path(path) - if not path.exists(): - raise FileNotFoundError(f"NemoGymJSONLDataset: file not found: {path}") - with open(path) as f: - for line in f: - line = line.strip() - if line: - self._rows.append(json.loads(line)) - - def __len__(self) -> int: - return len(self._rows) - - def __getitem__(self, idx: int) -> dict: - row = self._rows[idx] - - rcp = row.get("responses_create_params", {}) - messages = rcp.get("input", []) - raw_prompt = [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages] - - agent_ref = row.get("agent_ref", None) - - skip_keys = {"responses_create_params", "agent_ref"} - extra_env_info = {k: v for k, v in row.items() if k not in skip_keys} - - # preserve per-task fields like `tools` to nemo-gym request - # input/temperature/top_p/top_k are overridden per training step - # parallel_tool_calls is dropped because vLLM throws 500 error - _rcp_skip = {"input", "temperature", "top_p", "top_k", "parallel_tool_calls"} - rcp_extra = {k: v for k, v in rcp.items() if k not in _rcp_skip} - if rcp_extra: - extra_env_info["_rcp_extra"] = rcp_extra - - out = { - "raw_prompt": raw_prompt, - # unused placeholder. ray_trainer.py calls len(batch.batch) which requires batch to be non-empty - "__nemo_gym_batch_size__": torch.zeros(1, dtype=torch.long), - } - if agent_ref is not None: - out["agent_ref"] = agent_ref - out["extra_env_info"] = extra_env_info - - return out - - @property - def collate_fn(self): - from verl.utils.dataset.rl_dataset import collate_fn - - return collate_fn diff --git a/verl/experimental/nemo_gym/server_patch.py b/verl/experimental/nemo_gym/server_patch.py deleted file mode 100644 index 355c80d3e28..00000000000 --- a/verl/experimental/nemo_gym/server_patch.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -logger = logging.getLogger(__name__) - - -def _replace_prefix_tokens(model_prefix, template_prefix, template_ids, tok): - # matches nemo-rl's implementation - if not model_prefix: - return template_ids - eos = tok.eos_token_id - if eos is None: - raise ValueError("tokenizer must have eos_token_id") - cut_model = len(model_prefix) - if model_prefix[-1] == eos: - cut_model -= 1 - if len(template_ids) <= len(template_prefix): - raise ValueError( - f"non-monotonically increasing trajectory: " - f"template_ids={len(template_ids)} template_prefix={len(template_prefix)}" - ) - cut = -1 - for pos in reversed(range(len(template_prefix))): - if template_ids[pos] == eos: - cut = pos - break - if cut < 0: - raise ValueError("no EOS token found in chat-templated messages") - return model_prefix[:cut_model] + template_ids[cut:] - - -def _make_patched_preprocess_chat(original): - async def _patched( - self, - request, - messages, - default_template, - default_template_content_format, - default_template_kwargs, - tool_dicts=None, - tool_parser=None, - ): - required_prefix = getattr(request, "required_prefix_token_ids", None) - if required_prefix is None: - for msg in reversed(messages): - if isinstance(msg, dict) and "prompt_token_ids" in msg: - required_prefix = list(msg["prompt_token_ids"]) + list(msg["generation_token_ids"]) - break - elif not isinstance(msg, dict) and getattr(msg, "prompt_token_ids", None): - required_prefix = list(msg.prompt_token_ids) + list(msg.generation_token_ids) - break - - res = await original( - self, - request, - messages, - default_template, - default_template_content_format, - default_template_kwargs, - tool_dicts=tool_dicts, - tool_parser=tool_parser, - ) - - if required_prefix is None: - return res - - try: - # call _preprocess_chat on messages up to last assistant turn (no gen prompt) - # to get template_prefix_ids for _replace_prefix_tokens - last_asst = next( - ( - i - for i in reversed(range(len(messages))) - if ( - messages[i].get("role") if isinstance(messages[i], dict) else getattr(messages[i], "role", None) - ) - == "assistant" - ), - None, - ) - prefix_msgs = messages[: last_asst + 1] if last_asst is not None else messages - prefix_res = await original( - self, - request, - prefix_msgs, - default_template, - default_template_content_format, - {**(default_template_kwargs or {}), "add_generation_prompt": False}, - tool_dicts=tool_dicts, - tool_parser=tool_parser, - ) - template_prefix_ids = prefix_res[1][0]["prompt_token_ids"] - - tok = self.renderer.get_tokenizer() - engine_prompt = res[1][0] - engine_prompt["prompt_token_ids"] = _replace_prefix_tokens( - required_prefix, - template_prefix_ids, - engine_prompt["prompt_token_ids"], - tok, - ) - except Exception as e: - logger.warning(f"[nemo-gym] retokenization patch failed, skipping: {e}") - return res - - return _patched - - -def patch_serving_chat_for_nemo_gym() -> None: - import importlib - - targets = { - "OpenAIServingChat": ( - "vllm.entrypoints.openai.chat_completion.serving", - "vllm.entrypoints.openai.chat_completion", - "vllm.entrypoints.openai.api_server", - "vllm.entrypoints.openai.serving_chat", - ), - "OpenAIServingTokenization": ( - "vllm.entrypoints.openai.api_server", - "vllm.entrypoints.serve.tokenize.serving", - ), - } - - patched_any = False - for cls_name, mods in targets.items(): - cls = None - for mod in mods: - try: - m = importlib.import_module(mod) - if hasattr(m, cls_name): - cls = getattr(m, cls_name) - break - except ImportError: - continue - if cls is None: - logger.warning(f"[nemo-gym] could not find {cls_name}; skipping.") - continue - cls._preprocess_chat = _make_patched_preprocess_chat(cls._preprocess_chat) - logger.warning(f"[nemo-gym] applied retokenization patch to {cls_name}.") - patched_any = True - - if not patched_any: - logger.warning("[nemo-gym] retokenization patch not applied to any serving class.") diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index ee0b8111b61..d33e40965df 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -78,13 +78,6 @@ class CustomAsyncServerConfig(BaseConfig): name: Optional[str] = None -@dataclass -class NemoGymConfig(BaseConfig): - nemo_gym_root: Optional[str] = None - uses_reasoning_parser: bool = False - config_paths: Optional[list] = None - - @dataclass class AgentLoopConfig(BaseConfig): num_workers: int = 8 @@ -94,7 +87,6 @@ class AgentLoopConfig(BaseConfig): # Fully qualified class name for custom AgentLoopManager (e.g., "mypackage.module.MyManager"). # Security: This class will be dynamically imported via importlib. Only use trusted class paths. agent_loop_manager_class: Optional[str] = None - nemo_gym: Optional[NemoGymConfig] = None @dataclass From 3a3d5367f484c3b02e3721fb7b60c693953b908b Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 4 Apr 2026 02:06:04 -0700 Subject: [PATCH 18/19] revert import Signed-off-by: cmunley1 --- verl/workers/config/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index d33e40965df..886cb1f836e 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Optional from omegaconf import MISSING From f3880c9157ccef31a11b17216aa6640c0a01d862 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 4 Apr 2026 02:11:07 -0700 Subject: [PATCH 19/19] update patch import to recipe Signed-off-by: cmunley1 --- verl/workers/rollout/vllm_rollout/vllm_async_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 6e071a80bae..781c2bf9b28 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -172,7 +172,7 @@ def get_server_address(self): def apply_nemo_gym_server_patch(self): # called by NemoGymAgentLoopManager to apply retokenization fix only for nemo-gym runs - from verl.experimental.nemo_gym.server_patch import patch_serving_chat_for_nemo_gym + from recipe.nemo_gym.server_patch import patch_serving_chat_for_nemo_gym patch_serving_chat_for_nemo_gym() @property