diff --git a/nemo_gym/README.rst b/nemo_gym/README.rst
new file mode 100644
index 0000000..5635658
--- /dev/null
+++ b/nemo_gym/README.rst
@@ -0,0 +1,180 @@
+NVIDIA NeMo Gym Integration
+==================================
+
+`NVIDIA NeMo Gym <https://github.com/NVIDIA-NeMo/Gym>`_ is an RL environment framework for
+scalable, multi-environment, and agentic RL. Environments can be tested in NeMo Gym alone before
+training with verl. Visit the NeMo Gym documentation to learn more. This recipe demonstrates
+offline rollout collection as well as single- and multi-environment training on math and agentic
+workplace tasks with DAPO.
+
+Quickstart
+----------
+
+Local Rollout Collection
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+**1. Clone repositories**
+
+.. code-block:: bash
+
+    git clone https://github.com/verl-project/verl.git
+    git clone https://github.com/NVIDIA-NeMo/Gym.git
+    cd Gym
+
+**2. Set up NeMo Gym**
+
+.. code-block:: bash
+
+    # Install uv if needed
+    curl -LsSf https://astral.sh/uv/install.sh | sh
+    source $HOME/.local/bin/env
+
+    export UV_CACHE_DIR=/path/to/cache  # optional
+    uv venv --python 3.12
+    source .venv/bin/activate
+    uv sync --extra dev
+
+**3. Create an env.yaml with your policy model**
+
+For standalone testing, point at a local vLLM instance (or a hosted endpoint such as OpenAI):
+
+.. code-block:: yaml
+
+    # env.yaml
+    policy_base_url: http://localhost:8000/v1
+    policy_api_key: empty
+    policy_model_name: Qwen/Qwen3-4B-Instruct-2507
+
+**4. Start servers and test an environment**
+
+.. code-block:: bash
+
+    config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\
+    responses_api_models/vllm_model/configs/vllm_model.yaml"
+
+    ng_run "+config_paths=[${config_paths}]"
+
+**5. Collect and inspect rollouts**
+
+In a separate terminal:
+
+.. code-block:: bash
+
+    ng_collect_rollouts \
+        +agent_name=workplace_assistant_simple_agent \
+        +input_jsonl_fpath=resources_servers/workplace_assistant/data/example.jsonl \
+        +output_jsonl_fpath=results/rollouts.jsonl \
+        +limit=5
+
+    head -1 results/rollouts.jsonl | jq
+
+**6. Prepare training data**
+
+.. code-block:: bash
+
+    config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\
+    responses_api_models/vllm_model/configs/vllm_model_for_training.yaml"
+
+    ng_prepare_data \
+        "+config_paths=[${config_paths}]" \
+        +output_dirpath=data/workplace_assistant \
+        +mode=train_preparation \
+        +should_download=true \
+        +data_source=huggingface
+
+Check that each row has an ``agent_ref`` field. This is required for training.
+
+Training
+~~~~~~~~
+
+**7. Launch training**
+
+See ``submit_math.sh``, ``submit_workplace.sh``, or ``submit_multienv.sh`` for Slurm submission
+examples. The arguments most relevant to NeMo Gym:
+
+.. code-block:: bash
+
+    +data.custom_cls.path=recipe/nemo_gym/dataset.py
+    +data.custom_cls.name=NemoGymJSONLDataset
+    +actor_rollout_ref.rollout.agent.agent_loop_manager_class=recipe.nemo_gym.agent_loop.NemoGymAgentLoopManager
+    +actor_rollout_ref.rollout.agent.agent_loop_config_path=/path/to/configs/workplace.yaml
+
+Multi-Environment Training
+--------------------------
+
+To train on multiple environments simultaneously, create a mixed dataset where each row has an
+``agent_ref`` pointing to its environment, and include all environment config paths:
+
+.. code-block:: yaml
+
+    # configs/multienv.yaml
+    nemo_gym:
+      nemo_gym_root: $NEMO_GYM_ROOT
+      config_paths:
+        - $NEMO_GYM_ROOT/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml
+        - $NEMO_GYM_ROOT/resources_servers/math_with_judge/configs/math_with_judge.yaml
+        - $NEMO_GYM_ROOT/resources_servers/workplace_assistant/configs/workplace_assistant.yaml
+
+The first config tells verl and NeMo Gym to launch the model server, which tracks token IDs and
+log probs to prevent retokenization mismatches and standardizes generation behind the OpenAI
+Responses API.
+
+The remaining configs define the environments. Each config specifies an agent server and,
+optionally, a resources server that provides tools, state, verification, and reward logic. Some
+environments use only a ``responses_api_agents`` server and do not have a separate resources
+server.
+
+The data blend determines the sampling ratio between environments. If an environment curriculum
+or precise blending is desired, do not shuffle the dataset after creation. NeMo Gym routes each
+row to its environment via the ``agent_ref`` field.
+
+Note that some NeMo Gym environments, such as SWE-RL, launch containers and may require
+additional setup (e.g. Apptainer). See each environment's README in the NeMo Gym repo for
+details.
+
+Overview
+--------
+
+- ``agent_loop.py`` — ``NemoGymAgentLoopManager``: wraps NeMo Gym's rollout collection interface
+  to collect rollouts for input tasks. Converts results to verl's ``DataProto`` format.
+- ``dataset.py`` — ``NemoGymJSONLDataset``: loads NeMo Gym JSONL datasets.
+- ``server_patch.py`` — patches vLLM's ``OpenAIServingChat`` and ``OpenAIServingTokenization``
+  to correct for retokenization errors in multi-step rollouts, matching NeMo RL's approach.
+  **Tested with vLLM 0.17.0** (``verlai/verl:vllm017.latest``). The ``_preprocess_chat`` return
+  structure may change between vLLM versions — see the comment in ``server_patch.py``.
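+
+The JSONL rows consumed by ``dataset.py`` pair a Responses-API request with an environment
+reference. A minimal illustrative row (the field values here are hypothetical; this recipe reads
+``responses_create_params.input`` for the prompt messages and routes on ``agent_ref``):
+
+.. code-block:: json
+
+    {
+      "responses_create_params": {
+        "input": [{"role": "user", "content": "Reschedule my 3pm meeting to Friday."}]
+      },
+      "agent_ref": {"name": "workplace_assistant_simple_agent"}
+    }
+
+Any additional top-level keys in a row are carried along as ``extra_env_info``.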
+ +Requirements +------------ + +- A NeMo Gym clone with the environments you want to train on. +- ``pip install -e /path/to/nemo-gym`` in the container at job start. +- Container: ``verlai/verl:vllm017.latest`` (vLLM 0.17.0). + +Environment Variables +--------------------- + +The submit scripts source a ``config.env`` file for secrets and paths. Copy +``config.env.example`` and fill in your values: + +.. code-block:: bash + + cp recipe/nemo_gym/config.env.example config.env + +.. code-block:: bash + + VERL_ROOT=/path/to/verl + NEMO_GYM_ROOT=/path/to/nemo-gym + HF_HOME=/path/to/hf_home # Hugging Face model cache + RESULTS_ROOT=/path/to/results # checkpoints and rollout dumps + WANDB_USERNAME=your_username + WANDB_API_KEY=your_key + +Config YAML +----------- + +Each training run needs a config YAML (see ``configs/math.yaml`` for an example): + +.. code-block:: yaml + + nemo_gym: + nemo_gym_root: $NEMO_GYM_ROOT # path to NeMo Gym clone, expanded at runtime + uses_reasoning_parser: false # set true for reasoning models (e.g. DeepSeek-R1) + config_paths: + - $NEMO_GYM_ROOT/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml + - $NEMO_GYM_ROOT/resources_servers/your_env/configs/your_env.yaml diff --git a/nemo_gym/__init__.py b/nemo_gym/__init__.py new file mode 100644 index 0000000..1ce90c5 --- /dev/null +++ b/nemo_gym/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nemo_gym/agent_loop.py b/nemo_gym/agent_loop.py new file mode 100644 index 0000000..9ecc699 --- /dev/null +++ b/nemo_gym/agent_loop.py @@ -0,0 +1,514 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import asyncio +import importlib.util as _ilu +import os +import socket +import sys +import threading +from collections import defaultdict +from pathlib import Path + +# TODO: remove in next nemo gym release +_nemo_gym_path = os.environ.get("NEMO_GYM_ROOT") +if not _nemo_gym_path: + _spec = _ilu.find_spec("nemo_gym") + if _spec and _spec.origin: + _nemo_gym_path = str(Path(_spec.origin).parent.parent) +if _nemo_gym_path: + sys.path.insert(0, _nemo_gym_path) + +# ruff: noqa: E402 +from typing import Optional + +import ray +import torch +from omegaconf import DictConfig, OmegaConf + +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopManager, + AgentLoopMetrics, + _InternalAgentLoopOutput, +) +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopWorker as _AgentLoopWorker, +) +from verl.protocol import DataProto +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.model import compute_position_id_with_mask +from verl.utils.ray_utils import auto_await + +# _postprocess is an AgentLoopWorker method that stacks _InternalAgentLoopOutput +# into a DataProto batch. 
It accesses self.distillation_enabled and +# self.reward_loop_worker_handles, both of which we set in _init_nemo_gym. +_postprocess = _AgentLoopWorker._postprocess + + +class NemoGymAgentLoopManager(AgentLoopManager): + @classmethod + @auto_await + async def create( + cls, + config, + worker_group=None, + rollout_resource_pool=None, + reward_loop_worker_handles=None, + teacher_model_manager=None, + ) -> NemoGymAgentLoopManager: + instance = cls( + config, + worker_group, + rollout_resource_pool, + teacher_model_manager, + reward_loop_worker_handles, + ) + await instance._initialize_llm_servers() + await instance._init_global_load_balancer() + await instance._apply_server_patch() + await instance._init_nemo_gym() + return instance + + async def _apply_server_patch(self) -> None: + futs = [handle.apply_nemo_gym_server_patch.remote() for handle in self.server_handles] + await asyncio.get_event_loop().run_in_executor(None, ray.get, futs) + + async def _init_nemo_gym(self) -> None: + # read nemo_gym config from agent_loop_config_path YAML under a "nemo_gym" key + # uses os.path.expandvars to resolve $NEMO_GYM_ROOT in the YAML + nemo_gym_cfg = {} + cfg_path = getattr(self.rollout_config.agent, "agent_loop_config_path", None) + if cfg_path: + with open(cfg_path) as f: + content = os.path.expandvars(f.read()) + yaml_cfg = OmegaConf.create(content) + nemo_gym_cfg = OmegaConf.to_container(yaml_cfg.get("nemo_gym", {}), resolve=False) + + _raw = nemo_gym_cfg.get("nemo_gym_root") or "" + nemo_gym_root = _raw if (_raw and not _raw.startswith("$")) else os.environ.get("NEMO_GYM_ROOT") + + try: + from nemo_gym.cli import GlobalConfigDictParserConfig, RunHelper + from nemo_gym.rollout_collection import RolloutCollectionHelper + from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, BaseServerConfig + except ModuleNotFoundError as e: + raise ImportError("nemo-gym not found. 
Install it with: pip install -e /path/to/gym-ref") from e + + if not nemo_gym_root: + spec = _ilu.find_spec("nemo_gym") + if spec and spec.origin: + nemo_gym_root = str(Path(spec.origin).parent.parent) + + config_paths = [ + p.replace("$NEMO_GYM_ROOT", str(nemo_gym_root)) if nemo_gym_root else p + for p in (nemo_gym_cfg.get("config_paths") or []) + ] + initial_global_cfg = {"config_paths": config_paths} if config_paths else {} + + uses_reasoning_parser = nemo_gym_cfg.get("uses_reasoning_parser", False) + vllm_model_cfg = ( + initial_global_cfg.setdefault("policy_model", {}) + .setdefault("responses_api_models", {}) + .setdefault("vllm_model", {}) + ) + vllm_model_cfg["uses_reasoning_parser"] = uses_reasoning_parser + + if not uses_reasoning_parser: + vllm_model_cfg.setdefault("extra_body", {}).setdefault("chat_template_kwargs", {})["enable_thinking"] = ( + False + ) + + base_urls = [ + (addr if addr.startswith("http") else f"http://{addr}").rstrip("/") + "/v1" + for addr in self.server_addresses + ] + initial_global_cfg["policy_model_name"] = self.model_config.get("path", "") + initial_global_cfg["policy_api_key"] = "dummy_key" + initial_global_cfg["policy_base_url"] = base_urls + initial_global_cfg.setdefault("global_aiohttp_connector_limit_per_host", 16_384) + initial_global_cfg.setdefault("global_aiohttp_connector_limit", 65_536) + + ray_context = ray.get_runtime_context() + initial_global_cfg["ray_head_node_address"] = ray_context.gcs_address + + if nemo_gym_root: + initial_global_cfg.setdefault("uv_venv_dir", str(nemo_gym_root)) + initial_global_cfg.setdefault("skip_venv_if_present", True) + + node_ip = ray._private.services.get_node_ip_address() + with socket.socket() as s: + s.bind(("", 0)) + head_port = s.getsockname()[1] + initial_global_cfg[HEAD_SERVER_KEY_NAME] = {"host": "0.0.0.0", "port": head_port} + self._head_server_config = BaseServerConfig(host=node_ip, port=head_port) + + self._rh = RunHelper() + self._rh.start( + 
global_config_dict_parser_config=GlobalConfigDictParserConfig( + dotenv_path=None, + initial_global_config_dict=DictConfig(initial_global_cfg), + skip_load_from_cli=True, + ) + ) + + self._rch = RolloutCollectionHelper() + + self._tokenizer = omega_conf_to_dataclass(self.model_config).tokenizer + self.distillation_enabled = False + + self._rollout_loop = asyncio.new_event_loop() + self._rollout_thread = threading.Thread( + target=self._rollout_loop.run_forever, + daemon=True, + name="nemo-gym-rollout-loop", + ) + self._rollout_thread.start() + + print(f"NemoGymAgentLoopManager ready: {len(base_urls)} vLLM endpoints: {base_urls}") + + def generate_sequences(self, prompts: DataProto) -> DataProto: + future = asyncio.run_coroutine_threadsafe(self._async_generate_sequences(prompts), self._rollout_loop) + return future.result() + + async def _async_generate_sequences(self, prompts: DataProto) -> DataProto: + validate = prompts.meta_info.get("validate", False) + global_steps = prompts.meta_info.get("global_steps", -1) + + nemo_gym_examples = _build_nemo_gym_examples( + prompts, + self.rollout_config, + validate=validate, + ) + + nemo_gym_result_iterator = self._rch.run_examples( + examples=nemo_gym_examples, + head_server_config=self._head_server_config, + ) + + rowidxs, raw_results = [], [] + for task in nemo_gym_result_iterator: + nemo_gym_row, nemo_gym_result = await task + try: + result = _postprocess_nemo_gym_result(nemo_gym_result, self._tokenizer) + except ValueError: + result = _empty_result(nemo_gym_row, self._tokenizer) + result["env"] = nemo_gym_row["agent_ref"]["name"] + rowidxs.append(nemo_gym_row["_rowidx"]) + raw_results.append(result) + + results = [None] * len(nemo_gym_examples) + for rowidx, result in zip(rowidxs, raw_results, strict=True): + results[rowidx] = result + missing = [i for i, r in enumerate(results) if r is None] + if missing: + raise RuntimeError(f"nemo-gym did not return results for samples: {missing}") + + prompt_lens = 
[sum(len(m["token_ids"]) for m in r["input_message_log"]) for r in results] + response_lens = [ + sum(len(m["token_ids"]) for m in r["message_log"][len(r["input_message_log"]) :]) for r in results + ] + prompt_length = max(prompt_lens) if prompt_lens else self.rollout_config.prompt_length + _vllm_kwargs = (getattr(self.rollout_config, "engine_kwargs", {}) or {}).get("vllm", {}) or {} + max_model_len = ( + getattr(self.rollout_config, "max_model_len", None) + or _vllm_kwargs.get("max-model-len") + or _vllm_kwargs.get("max_model_len") + ) + response_budget = (int(max_model_len) - prompt_length) if max_model_len is not None else None + response_length = max(response_lens) if response_lens else self.rollout_config.response_length + if response_budget is not None: + response_length = max(0, min(response_length, response_budget)) + + internal_outputs = [ + _nemo_gym_result_to_verl( + result=result, + tokenizer=self._tokenizer, + prompt_length=prompt_length, + response_length=response_length, + ) + for result in results + ] + + output = _postprocess( + self, + internal_outputs, + input_non_tensor_batch=prompts.non_tensor_batch, + validate=validate, + ) + output.meta_info["global_steps"] = global_steps + rollout_metrics = _compute_rollout_metrics(results, getattr(self.rollout_config, "max_model_len", None)) + output.meta_info["timing"] = {} + output.meta_info["rollout_metrics"] = rollout_metrics + + env_metrics = {k: v for k, v in rollout_metrics.items() if k.startswith("env/")} + if env_metrics: + try: + import wandb + + if wandb.run is not None: + wandb.log({**env_metrics, "train/global_step": global_steps}, step=global_steps) + except Exception: + pass + + return output + + +def _build_nemo_gym_examples( + prompts: DataProto, + rollout_config, + validate: bool = False, +) -> list[dict]: + cfg = rollout_config + temperature = cfg.val_kwargs.temperature if validate else cfg.temperature + top_p = cfg.val_kwargs.top_p if validate else cfg.top_p + + non_tensor = 
prompts.non_tensor_batch + examples = [] + for i in range(len(prompts)): + messages = list(non_tensor["raw_prompt"][i]) + + if "agent_ref" not in non_tensor: + raise ValueError(f"dataset row {i} is missing agent_ref") + agent_ref = non_tensor["agent_ref"][i] + + rcp = {} + if "extra_env_info" in non_tensor and "_rcp_extra" in (non_tensor["extra_env_info"][i] or {}): + rcp.update(non_tensor["extra_env_info"][i]["_rcp_extra"]) + rcp.update({"input": messages, "temperature": temperature, "top_p": top_p}) + + row = { + "responses_create_params": rcp, + "agent_ref": agent_ref, + "_rowidx": i, + } + + if "extra_env_info" in non_tensor: + env_info = non_tensor["extra_env_info"][i] + if isinstance(env_info, dict): + for k, v in env_info.items(): + if k not in row: + row[k] = v + + examples.append(row) + return examples + + +def _postprocess_nemo_gym_result(nemo_gym_result: dict, tokenizer) -> dict: + message_log = [] + seen_token_ids: list[int] = [] + + for item in nemo_gym_result["response"]["output"]: + if "generation_token_ids" not in item: + continue + + prompt_ids = item["prompt_token_ids"] + + if seen_token_ids != prompt_ids[: len(seen_token_ids)]: + raise ValueError( + f"Non-contiguous token IDs (server_patch active?). 
seen={len(seen_token_ids)} prompt={len(prompt_ids)}" + ) + + message_log.append( + { + "role": "user", + "content": "", + "token_ids": torch.tensor(prompt_ids[len(seen_token_ids) :]), + } + ) + message_log.append( + { + "role": "assistant", + "content": "", + "token_ids": torch.tensor(item["generation_token_ids"]), + "generation_logprobs": torch.tensor(item["generation_log_probs"]), + } + ) + + seen_token_ids.extend(message_log[-2]["token_ids"].tolist()) + seen_token_ids.extend(message_log[-1]["token_ids"].tolist()) + + item.pop("prompt_token_ids", None) + item["prompt_str"] = tokenizer.decode(prompt_ids) + item["generation_str"] = tokenizer.decode(item.pop("generation_token_ids")) + item.pop("generation_log_probs") + + if not message_log: + raise ValueError( + "nemo-gym returned a result with no generation data. The prompt may exceed vLLM's max_model_len." + ) + + return { + "message_log": message_log, + "input_message_log": message_log[:1], + "full_result": nemo_gym_result, + } + + +def _empty_result(nemo_gym_row: dict, tokenizer) -> dict: + messages = nemo_gym_row.get("responses_create_params", {}).get("input", []) + raw_prompt = [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages] + prompt_ids = tokenizer.apply_chat_template(raw_prompt, tokenize=True, add_generation_prompt=False)[-1:] + dummy_tok = torch.tensor(prompt_ids, dtype=torch.long) + return { + "message_log": [ + {"role": "user", "token_ids": dummy_tok}, + {"role": "assistant", "token_ids": dummy_tok, "generation_logprobs": torch.zeros(len(dummy_tok))}, + ], + "input_message_log": [{"role": "user", "token_ids": dummy_tok}], + "full_result": {"reward": 0.0}, + } + + +def _nemo_gym_result_to_verl( + result: dict, + tokenizer, + prompt_length: int, + response_length: int, +) -> _InternalAgentLoopOutput: + """Pack message_log into padded tensors and mask non assistant messages""" + message_log = result["message_log"] + input_message_log = result["input_message_log"] + + 
prompt_ids_raw: list[int] = [] + for msg in input_message_log: + tids = msg["token_ids"] + prompt_ids_raw.extend(tids.tolist() if isinstance(tids, torch.Tensor) else tids) + + n_prompt_msgs = len(input_message_log) + response_ids_raw: list[int] = [] + response_mask_raw: list[int] = [] + response_logprobs_raw: list[float] = [] + + for msg in message_log[n_prompt_msgs:]: + tids = msg["token_ids"] + toks = tids.tolist() if isinstance(tids, torch.Tensor) else list(tids) + if msg["role"] == "assistant": + response_ids_raw.extend(toks) + response_mask_raw.extend([1] * len(toks)) + lp = msg.get("generation_logprobs") + response_logprobs_raw.extend( + lp.tolist() if isinstance(lp, torch.Tensor) else list(lp) if lp is not None else [0.0] * len(toks) + ) + else: + response_ids_raw.extend(toks) + response_mask_raw.extend([0] * len(toks)) + response_logprobs_raw.extend([0.0] * len(toks)) + + response_ids_raw = response_ids_raw[:response_length] + response_mask_raw = response_mask_raw[:response_length] + response_logprobs_raw = response_logprobs_raw[:response_length] + + tokenizer.padding_side = "left" + prompt_out = tokenizer.pad( + {"input_ids": prompt_ids_raw}, + padding="max_length", + max_length=prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + prompt_ids = prompt_out["input_ids"] + if prompt_ids.dim() == 1: + prompt_ids = prompt_ids.unsqueeze(0) + prompt_out["attention_mask"] = prompt_out["attention_mask"].unsqueeze(0) + + tokenizer.padding_side = "right" + response_out = tokenizer.pad( + {"input_ids": response_ids_raw}, + padding="max_length", + max_length=response_length, + return_tensors="pt", + return_attention_mask=True, + ) + response_ids = response_out["input_ids"] + response_attn = response_out["attention_mask"] + if response_ids.dim() == 1: + response_ids = response_ids.unsqueeze(0) + response_attn = response_attn.unsqueeze(0) + + pad = response_length - len(response_mask_raw) + response_mask = torch.tensor(response_mask_raw + [0] * pad, 
dtype=torch.long).unsqueeze(0) * response_attn + + pad = response_length - len(response_logprobs_raw) + response_logprobs = torch.tensor(response_logprobs_raw + [0.0] * pad, dtype=torch.float32).unsqueeze(0) + + attention_mask = torch.cat([prompt_out["attention_mask"], response_attn], dim=1) + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + + position_ids = compute_position_id_with_mask(attention_mask) + + reward_score: Optional[float] = result.get("full_result", {}).get("reward", None) + num_turns = sum(1 for m in message_log if m["role"] == "assistant") + + return _InternalAgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids, + input_ids=input_ids, + position_ids=position_ids, + response_mask=response_mask, + attention_mask=attention_mask, + response_logprobs=response_logprobs, + routed_experts=None, + multi_modal_inputs={}, + teacher_logprobs=None, + teacher_ids=None, + reward_score=reward_score, + num_turns=num_turns, + metrics=AgentLoopMetrics(), + extra_fields={}, + ) + + +def _compute_rollout_metrics(results: list[dict], max_model_len: Optional[int] = None) -> dict: + batch_size = len(results) + if batch_size == 0: + return {} + + def _mean(vals): + return sum(vals) / len(vals) if vals else 0.0 + + stats = [] + for r in results: + ml = r["message_log"] + total = sum(len(m["token_ids"]) for m in ml) + asst = sum(len(m["token_ids"]) for m in ml if m["role"] == "assistant") + turns = sum(1 for m in ml if m["role"] == "user") + hit_max = (max_model_len is not None) and (total >= max_model_len) + stats.append( + { + "reward": float(r.get("full_result", {}).get("reward", 0.0)), + "asst": asst, + "total": total, + "turns": turns, + "hit_max": hit_max, + } + ) + + metrics = { + "turns_per_sample/mean": _mean([s["turns"] for s in stats]), + "total_tokens_per_sample/mean": _mean([s["total"] for s in stats]), + "gen_tokens_per_sample/mean": _mean([s["asst"] for s in stats]), + "total_reward/mean": _mean([s["reward"] for s in stats]), + 
"natural_termination_rate": sum(not s["hit_max"] for s in stats) / batch_size, + "truncation_rate": sum(s["hit_max"] for s in stats) / batch_size, + } + + # per environment metrics + env_rewards: dict = defaultdict(list) + for r, s in zip(results, stats, strict=True): + env_rewards[r.get("env", "unknown")].append(s["reward"]) + for env, rewards in env_rewards.items(): + metrics[f"env/{env}/reward_mean"] = _mean(rewards) + + return metrics diff --git a/nemo_gym/config.env.example b/nemo_gym/config.env.example new file mode 100644 index 0000000..c4eb0aa --- /dev/null +++ b/nemo_gym/config.env.example @@ -0,0 +1,6 @@ +VERL_ROOT=/path/to/verl +NEMO_GYM_ROOT=/path/to/nemo-gym +HF_HOME=/path/to/hf_home # Hugging Face model cache +RESULTS_ROOT=/path/to/results # checkpoints and rollout dumps +WANDB_USERNAME=your_wandb_username +WANDB_API_KEY=your_key_here diff --git a/nemo_gym/configs/math.yaml b/nemo_gym/configs/math.yaml new file mode 100644 index 0000000..a1c683a --- /dev/null +++ b/nemo_gym/configs/math.yaml @@ -0,0 +1,6 @@ +nemo_gym: + nemo_gym_root: $NEMO_GYM_ROOT + uses_reasoning_parser: false + config_paths: + - $NEMO_GYM_ROOT/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml + - $NEMO_GYM_ROOT/resources_servers/math_with_judge/configs/math_with_judge.yaml diff --git a/nemo_gym/configs/multienv.yaml b/nemo_gym/configs/multienv.yaml new file mode 100644 index 0000000..49e50b0 --- /dev/null +++ b/nemo_gym/configs/multienv.yaml @@ -0,0 +1,7 @@ +nemo_gym: + nemo_gym_root: $NEMO_GYM_ROOT + uses_reasoning_parser: false + config_paths: + - $NEMO_GYM_ROOT/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml + - $NEMO_GYM_ROOT/resources_servers/workplace_assistant/configs/workplace_assistant.yaml + - $NEMO_GYM_ROOT/resources_servers/math_with_judge/configs/math_with_judge.yaml diff --git a/nemo_gym/configs/workplace.yaml b/nemo_gym/configs/workplace.yaml new file mode 100644 index 0000000..1d6af5e --- /dev/null +++ 
b/nemo_gym/configs/workplace.yaml @@ -0,0 +1,6 @@ +nemo_gym: + nemo_gym_root: $NEMO_GYM_ROOT + uses_reasoning_parser: false + config_paths: + - $NEMO_GYM_ROOT/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml + - $NEMO_GYM_ROOT/resources_servers/workplace_assistant/configs/workplace_assistant.yaml diff --git a/nemo_gym/dataset.py b/nemo_gym/dataset.py new file mode 100644 index 0000000..ab29089 --- /dev/null +++ b/nemo_gym/dataset.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +from pathlib import Path + +import torch +from torch.utils.data import Dataset + + +class NemoGymJSONLDataset(Dataset): + def __init__( + self, + data_files: str | list[str], + tokenizer, + processor=None, + config=None, + **kwargs, + ): + if isinstance(data_files, str): + data_files = [data_files] + + self.tokenizer = tokenizer + self._rows: list[dict] = [] + + for path in data_files: + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"NemoGymJSONLDataset: file not found: {path}") + with open(path) as f: + for line in f: + line = line.strip() + if line: + self._rows.append(json.loads(line)) + + def __len__(self) -> int: + return len(self._rows) + + def __getitem__(self, idx: int) -> dict: + row = self._rows[idx] + + rcp = row.get("responses_create_params", {}) + messages = rcp.get("input", []) + raw_prompt = [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in messages] + + agent_ref = row.get("agent_ref", None) + + skip_keys = {"responses_create_params", "agent_ref"} + extra_env_info = {k: v for k, v in row.items() if k not in skip_keys} + + # preserve per-task fields like `tools` to nemo-gym request + # input/temperature/top_p/top_k are overridden per training step + # parallel_tool_calls is dropped because vLLM throws 500 error + _rcp_skip = {"input", "temperature", "top_p", "top_k", "parallel_tool_calls"} + rcp_extra = {k: v for k, v in rcp.items() if k not in _rcp_skip} + if rcp_extra: + extra_env_info["_rcp_extra"] = rcp_extra + + out = { + "raw_prompt": raw_prompt, + # unused placeholder. 
ray_trainer.py calls len(batch.batch) which requires batch to be non-empty + "__nemo_gym_batch_size__": torch.zeros(1, dtype=torch.long), + } + if agent_ref is not None: + out["agent_ref"] = agent_ref + out["extra_env_info"] = extra_env_info + + return out + + @property + def collate_fn(self): + from verl.utils.dataset.rl_dataset import collate_fn + + return collate_fn diff --git a/nemo_gym/server_patch.py b/nemo_gym/server_patch.py new file mode 100644 index 0000000..a707ae4 --- /dev/null +++ b/nemo_gym/server_patch.py @@ -0,0 +1,128 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +logger = logging.getLogger(__name__) + + +def _replace_prefix_tokens(model_prefix, template_prefix, template_ids, tok): + # matches NeMo-RL implementation + if not model_prefix: + return template_ids + eos = tok.eos_token_id + if eos is None: + raise ValueError("tokenizer must have eos_token_id") + cut_model = len(model_prefix) + if model_prefix[-1] == eos: + cut_model -= 1 + if len(template_ids) <= len(template_prefix): + raise ValueError( + f"non-monotonically increasing trajectory: " + f"template_ids={len(template_ids)} template_prefix={len(template_prefix)}" + ) + cut = -1 + for pos in reversed(range(len(template_prefix))): + if template_ids[pos] == eos: + cut = pos + break + if cut < 0: + raise ValueError("no EOS token found in chat-templated messages") + return model_prefix[:cut_model] + template_ids[cut:] + + +def _make_patched_preprocess_chat(original): + async def _patched( + self, + request, + messages, + default_template, + default_template_content_format, + default_template_kwargs, + tool_dicts=None, + tool_parser=None, + ): + required_prefix = getattr(request, "required_prefix_token_ids", None) + if required_prefix is None: + for msg in reversed(messages): + if isinstance(msg, dict) and "prompt_token_ids" in msg: + required_prefix = list(msg["prompt_token_ids"]) + list(msg["generation_token_ids"]) + break + elif not isinstance(msg, dict) and getattr(msg, "prompt_token_ids", None): + required_prefix = list(msg.prompt_token_ids) + list(msg.generation_token_ids) + break + + try: + res = await original( + self, + request, + messages, + default_template, + default_template_content_format, + default_template_kwargs, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + ) + except ValueError as e: + if "maximum context length" in str(e): + logger.warning("Prompt exceeds max_model_len: %s", e) + raise + + if required_prefix is None: + return res + + last_asst = next( + ( + i + for i in reversed(range(len(messages))) + if 
(messages[i].get("role") if isinstance(messages[i], dict) else getattr(messages[i], "role", None))
+                == "assistant"
+            ),
+            None,
+        )
+        prefix_msgs = messages[: last_asst + 1] if last_asst is not None else messages
+        prefix_res = await original(
+            self,
+            request,
+            prefix_msgs,
+            default_template,
+            default_template_content_format,
+            {**(default_template_kwargs or {}), "add_generation_prompt": False},
+            tool_dicts=tool_dicts,
+            tool_parser=tool_parser,
+        )
+        # tested on vLLM 0.17.0. other versions may error
+        template_prefix_ids = prefix_res[1][0]["prompt_token_ids"]
+
+        tok = self.renderer.get_tokenizer()
+        engine_prompt = res[1][0]
+        engine_prompt["prompt_token_ids"] = _replace_prefix_tokens(
+            required_prefix,
+            template_prefix_ids,
+            engine_prompt["prompt_token_ids"],
+            tok,
+        )
+        return res
+
+    return _patched
+
+
+def patch_serving_chat_for_nemo_gym() -> None:
+    # vLLM 0.17.0 module paths
+    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
+    from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
+
+    for cls in (OpenAIServingChat, OpenAIServingTokenization):
+        cls._preprocess_chat = _make_patched_preprocess_chat(cls._preprocess_chat)
+        logger.warning(f"[nemo-gym] applied retokenization patch to {cls.__name__}.")
diff --git a/nemo_gym/submit_math.sh b/nemo_gym/submit_math.sh
new file mode 100755
index 0000000..1421f38
--- /dev/null
+++ b/nemo_gym/submit_math.sh
@@ -0,0 +1,166 @@
+#!/bin/bash
+#SBATCH --job-name=verl-nemogym-dapo-4b-math
+#SBATCH --nodes=4
+#SBATCH --ntasks-per-node=1
+#SBATCH --partition=your_partition
+#SBATCH --account=your_account
+#SBATCH --time=4:00:00
+#SBATCH --gres=gpu:8
+#SBATCH --exclusive
+#SBATCH --output=logs/slurm-%j.out
+#SBATCH --error=logs/slurm-%j.err
+
+set -euo pipefail
+
+GPUS_PER_NODE=8
+
+source "${SLURM_SUBMIT_DIR}/config.env"
+
+MODEL_PATH="/path/to/Qwen3-4B-Instruct"
+TRAIN_FILE="/path/to/math_with_judge/dapo17k_bytedtsinghua_train_nrl.jsonl"
+TEST_FILE="/path/to/math_with_judge/aime24_bytedtsinghua_validation_nrl.jsonl"
+CKPTS_DIR="${RESULTS_ROOT}/dapo-qwen3-4b-math"
+
+CONTAINER="verlai/verl:vllm017.latest"
+MOUNTS="/lustre:/lustre"
+
+mkdir -p "${CKPTS_DIR}"
+
+nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+nodes_array=($nodes)
+head_node=${nodes_array[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}')
+
+RAY_PORT=6379
+ip_head="${head_node_ip}:${RAY_PORT}"
+echo "Head node: ${head_node} (${head_node_ip})"
+
+SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}"
+
+echo "Starting Ray head on ${head_node}..."
+srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \
+    env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start --head \
+    --node-ip-address="${head_node_ip}" \
+    --port=${RAY_PORT} \
+    --num-gpus="${GPUS_PER_NODE}" \
+    --block &
+sleep 10
+
+worker_num=$((SLURM_JOB_NUM_NODES - 1))
+for ((i = 1; i <= worker_num; i++)); do
+    node_i=${nodes_array[$i]}
+    echo "Starting Ray worker ${i} on ${node_i}..."
+    srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \
+        env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start \
+        --address="${ip_head}" \
+        --num-gpus="${GPUS_PER_NODE}" \
+        --block &
+    sleep 5
+done
+
+CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head"
+echo "Waiting for ray-head container at ${CONTAINER_DIR}..."
+elapsed=0
+while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do
+    sleep 5
+    elapsed=$((elapsed + 5))
+done
+if [[ ! -d "${CONTAINER_DIR}" ]]; then
+    echo "ERROR: ray-head container never appeared after 300s"
+    exit 1
+fi
+echo "Container ready. Waiting 90s for all Ray workers to connect..."
+sleep 90
+
+echo "Installing nemo-gym..."
+srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-name=ray-head \ + bash -c "touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" + +echo "Launching training on ${head_node}..." +PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-workdir=${VERL_ROOT} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES \ + WANDB_API_KEY="${WANDB_API_KEY}" \ + HF_HOME="${HF_HOME}" \ + HF_HUB_CACHE="${HF_HOME}/hub" \ + RAY_ADDRESS="auto" \ + VLLM_USE_V1=1 \ + VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + NEMO_GYM_ROOT="${NEMO_GYM_ROOT}" \ + PYTHONPATH="${NEMO_GYM_ROOT}:${VERL_ROOT}" \ + RAY_grpc_keepalive_time_ms=60000 \ + RAY_grpc_keepalive_timeout_ms=600000 \ + RAY_grpc_client_keepalive_time_ms=60000 \ + RAY_grpc_client_keepalive_timeout_ms=600000 \ + python3 -m verl.trainer.main_ppo \ + --config-path="${VERL_ROOT}/recipe/dapo/config" \ + --config-name=dapo_megatron_trainer.yaml \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + +data.custom_cls.path="${VERL_ROOT}/recipe/nemo_gym/dataset.py" \ + +data.custom_cls.name=NemoGymJSONLDataset \ + data.truncation=left \ + data.train_batch_size=32 \ + actor_rollout_ref.rollout.n=16 \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.28 \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + 
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.param_offload=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.max_resp_len=32768 \ + 'trainer.logger=["console","wandb"]' \ + trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ + trainer.experiment_name=dapo-qwen3-4b-math \ + trainer.n_gpus_per_node=${GPUS_PER_NODE} \ + trainer.nnodes=${SLURM_JOB_NUM_NODES} \ + 
trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable \ + trainer.log_val_generations=10 \ + +actor_rollout_ref.rollout.agent.agent_loop_manager_class='recipe.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ + +actor_rollout_ref.rollout.agent.agent_loop_config_path="${VERL_ROOT}/recipe/nemo_gym/configs/math.yaml" \ + 2>&1 diff --git a/nemo_gym/submit_multienv.sh b/nemo_gym/submit_multienv.sh new file mode 100755 index 0000000..f4731b7 --- /dev/null +++ b/nemo_gym/submit_multienv.sh @@ -0,0 +1,179 @@ +#!/bin/bash +#SBATCH --job-name=verl-nemogym-dapo-multienv +#SBATCH --nodes=16 +#SBATCH --ntasks-per-node=1 +#SBATCH --partition=your_partition +#SBATCH --account=your_account +#SBATCH --time=4:00:00 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --output=logs/slurm-%j.out +#SBATCH --error=logs/slurm-%j.err + +set -euo pipefail + +GPUS_PER_NODE=8 + +source "${SLURM_SUBMIT_DIR}/config.env" + +MODEL_PATH="/path/to/Qwen3-4B-Instruct" +TRAIN_FILE="/path/to/multienv/train.jsonl" +TEST_FILE="/path/to/multienv/validation.jsonl" +CKPTS_DIR="${RESULTS_ROOT}/dapo-qwen3-4b-multienv" +ROLLOUT_DIR="${RESULTS_ROOT}/dapo-qwen3-4b-multienv-rollouts" + +CONTAINER="verlai/verl:vllm017.latest" +MOUNTS="/lustre:/lustre" + +mkdir -p "${CKPTS_DIR}" "${ROLLOUT_DIR}" + +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}') + +RAY_PORT=6379 +ip_head="${head_node_ip}:${RAY_PORT}" +echo "Head node: ${head_node} (${head_node_ip})" + +SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}" + +echo "Starting Ray head on ${head_node}..." 
+srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" \ + NEMO_GYM_ROOT="${NEMO_GYM_ROOT}" \ + PYTHONPATH="${NEMO_GYM_ROOT}:${VERL_ROOT}" \ + ray start --head \ + --node-ip-address="${head_node_ip}" \ + --port=${RAY_PORT} \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & +sleep 10 + +worker_num=$((SLURM_JOB_NUM_NODES - 1)) +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting Ray worker ${i} on ${node_i}..." + srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" \ + NEMO_GYM_ROOT="${NEMO_GYM_ROOT}" \ + PYTHONPATH="${NEMO_GYM_ROOT}:${VERL_ROOT}" \ + ray start \ + --address="${ip_head}" \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & + sleep 5 +done + +CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head" +echo "Waiting for ray-head container at ${CONTAINER_DIR}..." +elapsed=0 +while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do + sleep 5 + elapsed=$((elapsed + 5)) +done +if [[ ! -d "${CONTAINER_DIR}" ]]; then + echo "ERROR: ray-head container never appeared after 300s" + exit 1 +fi +echo "Container ready. Waiting 90s for all Ray workers to connect..." +sleep 90 + +echo "Installing nemo-gym..." +srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-name=ray-head \ + bash -c "PYTHONPATH= touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" + +# TODO: test if hermes tool parser still hits "already borrowed" tokenizer errors under concurrent load +# if so, point to or provide the patch here, or use a different model+tool parser + +echo "Launching training on ${head_node}..." 
+PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-workdir=${VERL_ROOT} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES \ + WANDB_API_KEY="${WANDB_API_KEY}" \ + HF_HOME="${HF_HOME}" \ + HF_HUB_CACHE="${HF_HOME}/hub" \ + RAY_ADDRESS="auto" \ + VLLM_USE_V1=1 \ + TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + NEMO_GYM_ROOT="${NEMO_GYM_ROOT}" \ + PYTHONPATH="${NEMO_GYM_ROOT}:${VERL_ROOT}" \ + VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + RAY_grpc_keepalive_time_ms=60000 \ + RAY_grpc_keepalive_timeout_ms=600000 \ + RAY_grpc_client_keepalive_time_ms=60000 \ + RAY_grpc_client_keepalive_timeout_ms=600000 \ + python3 -m verl.trainer.main_ppo \ + --config-path="${VERL_ROOT}/recipe/dapo/config" \ + --config-name=dapo_megatron_trainer.yaml \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + +data.custom_cls.path="${VERL_ROOT}/recipe/nemo_gym/dataset.py" \ + +data.custom_cls.name=NemoGymJSONLDataset \ + data.truncation=left \ + data.train_batch_size=32 \ + actor_rollout_ref.rollout.n=16 \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.28 \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + 
actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes' \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.param_offload=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.max_resp_len=32768 \ + 'trainer.logger=["console","wandb"]' \ + trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ + trainer.experiment_name=dapo-qwen3-4b-multienv \ + trainer.n_gpus_per_node=${GPUS_PER_NODE} \ + trainer.nnodes=${SLURM_JOB_NUM_NODES} \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable \ + trainer.log_val_generations=10 \ + +trainer.rollout_data_dir="${ROLLOUT_DIR}" 
\ + +actor_rollout_ref.rollout.agent.agent_loop_manager_class='recipe.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ + +actor_rollout_ref.rollout.agent.agent_loop_config_path="${VERL_ROOT}/recipe/nemo_gym/configs/multienv.yaml" \ + 2>&1 diff --git a/nemo_gym/submit_workplace.sh b/nemo_gym/submit_workplace.sh new file mode 100644 index 0000000..b9df6dc --- /dev/null +++ b/nemo_gym/submit_workplace.sh @@ -0,0 +1,171 @@ +#!/bin/bash +#SBATCH --job-name=verl-nemogym-dapo-4b-workplace +#SBATCH --nodes=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --partition=your_partition +#SBATCH --account=your_account +#SBATCH --time=4:00:00 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --output=logs/slurm-%j.out +#SBATCH --error=logs/slurm-%j.err + +set -euo pipefail + +GPUS_PER_NODE=8 + +source "${SLURM_SUBMIT_DIR}/config.env" + +MODEL_PATH="/path/to/Qwen3-4B-Instruct" +TRAIN_FILE="/path/to/workplace_assistant/train.jsonl" +TEST_FILE="/path/to/workplace_assistant/validation.jsonl" +CKPTS_DIR="${RESULTS_ROOT}/dapo-qwen3-4b-workplace" + +CONTAINER="verlai/verl:vllm017.latest" +MOUNTS="/lustre:/lustre" + +mkdir -p "${CKPTS_DIR}" + +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}') + +RAY_PORT=6379 +ip_head="${head_node_ip}:${RAY_PORT}" +echo "Head node: ${head_node} (${head_node_ip})" + +SRUN_ARGS="--no-container-mount-home --container-image=${CONTAINER} --container-mounts=${MOUNTS} --container-workdir=${VERL_ROOT}" + +echo "Starting Ray head on ${head_node}..." 
+srun --nodes=1 --ntasks=1 -w "${head_node}" ${SRUN_ARGS} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start --head \ + --node-ip-address="${head_node_ip}" \ + --port=${RAY_PORT} \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & +sleep 10 + +worker_num=$((SLURM_JOB_NUM_NODES - 1)) +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting Ray worker ${i} on ${node_i}..." + srun --nodes=1 --ntasks=1 -w "${node_i}" ${SRUN_ARGS} \ + env -u ROCR_VISIBLE_DEVICES WANDB_API_KEY="${WANDB_API_KEY}" ray start \ + --address="${ip_head}" \ + --num-gpus="${GPUS_PER_NODE}" \ + --block & + sleep 5 +done + +CONTAINER_DIR="/raid/enroot/data/user-${UID}/pyxis_${SLURM_JOB_ID}_ray-head" +echo "Waiting for ray-head container at ${CONTAINER_DIR}..." +elapsed=0 +while [[ ! -d "${CONTAINER_DIR}" && ${elapsed} -lt 300 ]]; do + sleep 5 + elapsed=$((elapsed + 5)) +done +if [[ ! -d "${CONTAINER_DIR}" ]]; then + echo "ERROR: ray-head container never appeared after 300s" + exit 1 +fi +echo "Container ready. Waiting 90s for all Ray workers to connect..." +sleep 90 + +echo "Installing nemo-gym..." +srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-name=ray-head \ + bash -c "PYTHONPATH= touch ${NEMO_GYM_ROOT}/scripts/__init__.py && pip install -q uv && echo 'blinker==1.4' > /tmp/constraints.txt && pip install -q -e ${NEMO_GYM_ROOT} -c /tmp/constraints.txt" + +# TODO: test if hermes tool parser still hits "already borrowed" tokenizer errors under concurrent load +# if so, point to or provide the patch here, or use a different model+tool parser + +echo "Launching training on ${head_node}..." 
+PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "${head_node}" \ + --no-container-mount-home --container-mounts=${MOUNTS} \ + --container-workdir=${VERL_ROOT} --container-name=ray-head \ + env -u ROCR_VISIBLE_DEVICES \ + WANDB_API_KEY="${WANDB_API_KEY}" \ + HF_HOME="${HF_HOME}" \ + HF_HUB_CACHE="${HF_HOME}/hub" \ + RAY_ADDRESS="auto" \ + VLLM_USE_V1=1 \ + TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + NEMO_GYM_ROOT="${NEMO_GYM_ROOT}" \ + PYTHONPATH="${NEMO_GYM_ROOT}:${VERL_ROOT}" \ + VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + RAY_grpc_keepalive_time_ms=60000 \ + RAY_grpc_keepalive_timeout_ms=600000 \ + RAY_grpc_client_keepalive_time_ms=60000 \ + RAY_grpc_client_keepalive_timeout_ms=600000 \ + python3 -m verl.trainer.main_ppo \ + --config-path="${VERL_ROOT}/recipe/dapo/config" \ + --config-name=dapo_megatron_trainer.yaml \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + +data.custom_cls.path="${VERL_ROOT}/recipe/nemo_gym/dataset.py" \ + +data.custom_cls.name=NemoGymJSONLDataset \ + data.truncation=left \ + data.train_batch_size=32 \ + actor_rollout_ref.rollout.n=16 \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.28 \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + 
actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.enable-auto-tool-choice=true' \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.tool-call-parser=hermes' \ + '+actor_rollout_ref.rollout.engine_kwargs.vllm.max-model-len=32768' \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.param_offload=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.max_resp_len=32768 \ + 'trainer.logger=["console","wandb"]' \ + trainer.project_name=${WANDB_USERNAME}-verl-nemogym-int \ + trainer.experiment_name=dapo-qwen3-4b-workplace \ + trainer.n_gpus_per_node=${GPUS_PER_NODE} \ + trainer.nnodes=${SLURM_JOB_NUM_NODES} \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable \ + trainer.log_val_generations=10 \ + 
+actor_rollout_ref.rollout.agent.agent_loop_manager_class='recipe.nemo_gym.agent_loop.NemoGymAgentLoopManager' \ + +actor_rollout_ref.rollout.agent.agent_loop_config_path="${VERL_ROOT}/recipe/nemo_gym/configs/workplace.yaml" \ + 2>&1