Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions examples/agent_train/train_qwen3_moe_rc.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
#!/usr/bin/env bash
set -xeuo pipefail

RAY_DATA_HOME=/mnt/hdfs/yyding
NNODES_ROLLOUT=16
NNODES_TRAIN=4
GEN_TP=2

project_name=${PROJECT_NAME:-'Uni-Agent-Qwen3-Coder-30B-megatron'}
exp_name=${EXP_NAME:-"$(date +%Y%m%d%H)_exp"}

RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-Coder-30B-A3B-Instruct"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/swe_agent/swe_rebench_v2_modal.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/swe_agent/swe_bench_verified.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/swe_agent/swe_bench_verified_modal.parquet"}
RUNTIME_ENV=${RUNTIME_ENV:-"${RAY_DATA_HOME}/data/swe_agent/runtime_env.yaml"}
# Must be launched from the repository root so Ray packages both `verl/` and `uni_agent/`.
AGENT_CONFIG_PATH=${AGENT_CONFIG_PATH:-"${RAY_DATA_HOME}/data/swe_agent/agent_config.yaml"}
Expand Down Expand Up @@ -70,8 +75,8 @@ NNODES_TRAIN=${NNODES_TRAIN:-4}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

train_prompt_bsz=0
n_resp_per_prompt=${N_RESP_PER_PROMPT:-16}
train_prompt_mini_bsz=${PPO_MINI_BATCH_SIZE:-8}
n_resp_per_prompt=${N_RESP_PER_PROMPT:-8}
train_prompt_mini_bsz=${PPO_MINI_BATCH_SIZE:-16}
total_rollout_steps=${TOTAL_ROLLOUT_STEPS:-10000}
test_freq=${TEST_FREQ:-10}
staleness_threshold=${STALENESS_THRESHOLD:-1.0}
Expand All @@ -95,8 +100,7 @@ rollout_rs_threshold=${ROLLOUT_RS_THRESHOLD:-"0.999_1.001"} # k1: "lo_hi" r
router_replay_mode=${ROUTER_REPLAY_MODE:-R3} # disabled | R2 | R3
enable_rollout_routing_replay=${ENABLE_ROLLOUT_ROUTING_REPLAY:-True} # required for R3 (rollout-side replay)

ray job submit --no-wait --runtime-env $RUNTIME_ENV \
-- python3 -m verl.experimental.fully_async_policy.fully_async_main \
python3 -m verl.experimental.fully_async_policy.fully_async_main \
--config-name='fully_async_ppo_megatron_trainer.yaml' \
hydra.searchpath=[pkg://verl.trainer.config] \
data.train_files="${TRAIN_FILE}" \
Expand Down
44 changes: 43 additions & 1 deletion uni_agent/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any

import numpy as np
import yaml

from uni_agent.async_logging import add_file_handler, get_logger
Expand Down Expand Up @@ -43,6 +44,12 @@ def _deep_merge(base: dict, overrides: dict) -> dict:

class UniAgentLoop(AgentLoopBase):
_semaphore: asyncio.Semaphore | None = None
# Cached (num_hidden_layers, num_experts_per_tok) of the rollout model. Used to
# synthesize a zero ``routed_experts`` for failed/empty trajectories when router
# replay (R3) is enabled. ``None`` after resolution means no replay tensor is needed
# (replay disabled or the model is dense / not MoE).
_routing_replay_shape: tuple[int, int] | None = None
_routing_replay_resolved: bool = False

async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
config_dict = self._init_config(sampling_params, **kwargs)
Expand Down Expand Up @@ -159,14 +166,49 @@ async def _build_empty_agent_output(self, exit_reason: str) -> AgentLoopOutput:
response_ids=[dummy_token_id] * dummy_response_length,
response_mask=[0] * dummy_response_length,
response_logprobs=[0.0] * dummy_response_length,
routed_experts=None,
routed_experts=self._synth_failed_routed_experts(dummy_response_length),
multi_modal_data={},
reward_score=0,
num_turns=0,
metrics={},
extra_fields=extra_fields,
)

def _synth_failed_routed_experts(self, length: int) -> np.ndarray | None:
"""Synthesize a zero ``routed_experts`` of shape ``(length, num_layers, top_k)``."""
shape = self._get_routing_replay_shape()
if shape is None:
return None
num_layers, top_k = shape
return np.zeros((length, num_layers, top_k), dtype=np.int64)

def _get_routing_replay_shape(self) -> tuple[int, int] | None:
"""Resolve and cache ``(num_hidden_layers, num_experts_per_tok)`` for the rollout
model. Returns ``None`` if rollout routing replay is off or the model has no
experts. The HF config is loaded at most once per worker process."""
rollout_cfg = self.config.actor_rollout_ref.rollout
if not bool(getattr(rollout_cfg, "enable_rollout_routing_replay", False)):
return None
cls = UniAgentLoop
if not cls._routing_replay_resolved:
from transformers import AutoConfig

model_path = self.config.actor_rollout_ref.model.path
model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
# Newer Qwen3 nests MoE fields under ``text_config``; older configs keep them
# at the top level. ``... or 0`` guards against fields explicitly set to None.
text_cfg = getattr(model_cfg, "text_config", None) or model_cfg
num_layers = int(getattr(text_cfg, "num_hidden_layers", 0) or 0) or int(
getattr(model_cfg, "num_hidden_layers", 0) or 0
)
top_k = int(getattr(text_cfg, "num_experts_per_tok", 0) or 0) or int(
getattr(model_cfg, "num_experts_per_tok", 0) or 0
)
cls._routing_replay_shape = (num_layers, top_k) if num_layers > 0 and top_k > 0 else None
cls._routing_replay_resolved = True
self.logger.info(f"routed_experts replay shape resolved: {cls._routing_replay_shape}")
return cls._routing_replay_shape

def _save_interaction_result(self, interaction_result: dict):
self.output_dir.mkdir(parents=True, exist_ok=True)
# rollout_cache: binary pickle for fast I/O (no readability needed)
Expand Down
Loading