diff --git a/examples/swe_agent_235b/README.md b/examples/swe_agent_235b/README.md new file mode 100644 index 00000000..079a4652 --- /dev/null +++ b/examples/swe_agent_235b/README.md @@ -0,0 +1,92 @@ +# SWE-bench Agent Training Example + +End-to-end recipe for training a **SWE-bench coding agent** with the **Uni-Agent** framework, using **fully-async RL** (Megatron actors + vLLM rollout replicas on separate nodes) and **Modal swe-rex sandboxes** for safe, parallel code execution during rollout. + +The reference configuration trains **Qwen3-235B-A22B-Instruct-2507** with GRPO, but the launch script is fully parameterized — point it at a smaller model and shrink the topology to reproduce on fewer GPUs. + +The agent solves each SWE-bench task by iteratively calling three tools inside a per-task Modal sandbox: + +- `str_replace_editor` — view / edit repository files +- `execute_bash` — run shell commands (build, run tests, inspect) +- `submit` — submit the final patch for evaluation + +The reward is computed by running the task's test suite against the submitted patch (`uni_agent.reward.swe_rebench` for training, `uni_agent.reward.swe_bench` for SWE-bench Verified). + +--- + +## Prerequisites + +- A Ray cluster with GPU nodes (the reference uses 12 nodes × 4 GPU: 8 train + 4 rollout). A working verl + Megatron + vLLM install on every node. +- A [Modal](https://modal.com) account and API token. The rollout spins up one swe-rex sandbox per in-flight trajectory, so size your concurrency against your Modal workspace's sandbox quota (see `agent_config.yaml`). +- A Weights & Biases account (or change `trainer.logger` in the launch script). 
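As a rough sizing aid, the arithmetic behind the reference topology and the suggested starting concurrency can be sketched in a few lines of Python. This is only an illustration: `Topology`, `starting_concurrency` and the `sandbox_quota` argument are hypothetical names that are not part of uni-agent or verl, and the sketch assumes each vLLM replica occupies exactly `INFER_TP` GPUs (as in the reference setup, where `INFER_EP=1`).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Topology:
    """Hypothetical helper describing a train/rollout split of a Ray cluster."""

    train_nodes: int
    rollout_nodes: int
    gpus_per_node: int
    infer_tp: int  # GPUs per vLLM replica (tensor-parallel size)

    @property
    def train_gpus(self) -> int:
        return self.train_nodes * self.gpus_per_node

    @property
    def rollout_gpus(self) -> int:
        return self.rollout_nodes * self.gpus_per_node

    @property
    def rollout_replicas(self) -> int:
        # Assumption: one replica per group of `infer_tp` rollout GPUs.
        if self.rollout_gpus % self.infer_tp:
            raise ValueError("rollout GPUs must be a multiple of infer_tp")
        return self.rollout_gpus // self.infer_tp


def starting_concurrency(
    topology: Topology,
    per_replica: int = 20,
    sandbox_quota: Optional[int] = None,
) -> int:
    """Suggested initial `concurrency`: ~20 per rollout replica, capped by the
    Modal sandbox quota when one is given (the quota value is an assumption
    you must look up for your own workspace)."""
    wanted = per_replica * topology.rollout_replicas
    return wanted if sandbox_quota is None else min(wanted, sandbox_quota)


# Reference setup: 8 train nodes + 4 rollout nodes, 4 GPUs each, INFER_TP=4.
reference = Topology(train_nodes=8, rollout_nodes=4, gpus_per_node=4, infer_tp=4)
print(reference.train_gpus, reference.rollout_gpus, reference.rollout_replicas)
print(starting_concurrency(reference))
```

For the reference setup this gives 32 training GPUs, 16 rollout GPUs, 4 replicas and a starting concurrency of 80, which matches the `concurrency: 80` shipped in `agent_config.yaml`. Passing a smaller `sandbox_quota` simply caps the result.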
+ +## Step 1: Prepare the datasets + +Build the train (SWE-reBench) and validation (SWE-bench Verified) parquet files with the existing preprocessing scripts: + +```bash +python examples/data_preprocess/swe_rebench.py --local-save-dir ~/data/swe_agent +python examples/data_preprocess/swe_bench_verified.py --local-save-dir ~/data/swe_agent +``` + +These write `swe_rebench_filtered_*.parquet` and `swe_bench_verified_*.parquet` into `--local-save-dir`. Make that directory reachable from every Ray node (shared filesystem or copied), then point `TRAIN_FILE` / `TEST_FILE` at the exact files produced (see Step 3). + +## Step 2: Configure the runtime env + +Copy `runtime_env.yaml` and fill in the placeholders: + +- `working_dir` and `PYTHONPATH` → your uni-agent and verl checkouts. +- `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET` → your Modal token (`modal token new`, or `modal token set --profile=` for a team workspace). Alternatively leave them unset and rely on `~/.modal.toml` on every node. +- `WANDB_API_KEY` → your W&B key (or run `wandb login` on the nodes and remove it). + +The file also documents two settings worth keeping: + +- **Do not** set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` — it is incompatible with vLLM's sleep-mode `CuMemAllocator` (pytorch/pytorch#147851). +- Set `CUDA_HOME` — some MoE expert-parallel kernels JIT-compile at runtime and need it to locate `nvcc`. 
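The two settings above are easy to get wrong silently, so a small pre-flight check can help before submitting the job. The following is a minimal sketch, not part of uni-agent or verl: `check_runtime_env` is a hypothetical helper that inspects a plain mapping of environment variables, and it assumes `nvcc` lives at `$CUDA_HOME/bin/nvcc`.

```python
import os
from pathlib import Path
from typing import List, Mapping


def check_runtime_env(env: Mapping[str, str]) -> List[str]:
    """Return human-readable problems found in an environment-variable mapping."""
    problems = []

    # PYTORCH_CUDA_ALLOC_CONF is a comma-separated list of key:value options.
    alloc_conf = env.get("PYTORCH_CUDA_ALLOC_CONF", "")
    options = [opt.strip().lower() for opt in alloc_conf.split(",")]
    if "expandable_segments:true" in options:
        problems.append(
            "PYTORCH_CUDA_ALLOC_CONF enables expandable_segments, which is "
            "incompatible with vLLM's sleep-mode CuMemAllocator"
        )

    # JIT-compiled MoE kernels need CUDA_HOME to locate nvcc.
    cuda_home = env.get("CUDA_HOME", "")
    if not cuda_home:
        problems.append("CUDA_HOME is unset; JIT-compiled kernels cannot find nvcc")
    elif not (Path(cuda_home) / "bin" / "nvcc").is_file():
        problems.append(f"no nvcc found under {cuda_home}/bin")

    return problems


if __name__ == "__main__":
    for problem in check_runtime_env(os.environ):
        print("WARNING:", problem)
```

Running it on each node (or inside a throwaway Ray task that uses the same runtime env) checks the environment the workers will actually see, rather than the one in your login shell.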
+ +## Step 3: Launch training + +```bash +export RAY_ADDRESS=http://:8265 +export DATA_ROOT=/path/to/data-root # holds hf-models/ and data/swe_agent/ +export TRAIN_FILE=$DATA_ROOT/data/swe_agent/swe_rebench_filtered_.parquet +export TEST_FILE=$DATA_ROOT/data/swe_agent/swe_bench_verified_.parquet + +bash examples/swe_agent_235b/train_qwen3_235b_swebench.sh +``` + +Topology and parallelism are env-overridable, e.g.: + +```bash +NNODES_TRAIN=8 NNODES_ROLLOUT=4 NGPUS_PER_NODE=4 \ +ACTOR_TP=2 ACTOR_CP=2 ACTOR_PP=8 ACTOR_EP=4 ACTOR_ETP=1 \ +INFER_TP=4 \ +bash examples/swe_agent_235b/train_qwen3_235b_swebench.sh +``` + +Notable settings baked into the script (see its header for the full rationale): + +- `max_response_length=128K` — SWE-bench trajectories are long (empirically mean ~70K tokens, ~90 turns); a 32K cap truncates roughly half of them. +- `tool_parser: hermes` (in `agent_config.yaml`) — Qwen3-235B-A22B uses the Hermes tool-call template; the wrong parser silently breaks tool calls. +- `moe_token_dispatcher_type=alltoall` — portable MoE dispatch (no extra expert-parallel comm library required). +- `VLLM_USE_DEEP_GEMM=0` — works around a vLLM 0.21 EP/CUTLASS init issue. +- `performance_mode=interactivity` — favoured throughput in our rollout concurrency sweep for this model. + +## Step 4: Monitor + +- W&B: reward curve under the configured `project_name` / `experiment_name`. +- Optional Prometheus rollout metrics: set `ENABLE_PROMETHEUS_MONITORING=true` and `PROMETHEUS_CONFIG_FILE=...`; verl rewrites the scrape targets to the live vLLM replicas automatically. +- Per-trajectory agent logs: `log_dir` in `agent_config.yaml` (default `/tmp/swe_agent_rollout_logs//run.log`). + +## Tuning notes + +- **Rollout concurrency** (`concurrency` in `agent_config.yaml`) is the main throughput/stability knob. Too high vs. the vLLM KV budget causes a preemption cascade; too high vs. your Modal quota causes sandbox-create failures. 
Start around `20 × (rollout replicas)` and ramp up once steady. +- **Checkpoint storage**: `save_freq=1` + `max_actor_ckpt_to_keep=2` keeps only the two most recent checkpoints; raise `save_freq` if I/O-bound. + +## Files + +| File | Purpose | +|---|---| +| `train_qwen3_235b_swebench.sh` | Ray job submit + full GRPO / Megatron / vLLM config | +| `agent_config.yaml` | UniAgentLoop config: tools, Modal deployment, concurrency, reward | +| `runtime_env.yaml` | Ray runtime env template (fill in tokens / paths) | diff --git a/examples/swe_agent_235b/agent_config.yaml b/examples/swe_agent_235b/agent_config.yaml new file mode 100644 index 00000000..0f6e9a7d --- /dev/null +++ b/examples/swe_agent_235b/agent_config.yaml @@ -0,0 +1,52 @@ +# Agent-loop config for SWE-bench RL training with Modal swe-rex sandboxes. +# +# Referenced from the launch script via +# actor_rollout_ref.rollout.agent.agent_loop_config_path= +# +# The `concurrency` field is the total number of in-flight trajectories across +# the whole rollout fleet (it is divided by rollout.agent.num_workers to get a +# per-worker semaphore). Tune it against your rollout KV budget and Modal +# sandbox quota: +# - Too high relative to KV cache -> vLLM preemption cascade at high KV usage. +# - Too high relative to your Modal workspace cap -> sandbox-create failures. +# A safe starting point is ~20 x (number of rollout replicas); ramp up once the +# run is steady with no preemption. SWE-bench trajectories are long +# (max_response up to 128K), so leave generous headroom for the long tail. +- name: swe_agent + + _target_: uni_agent.agent_loop.UniAgentLoop + + concurrency: 80 + log_dir: /tmp/swe_agent_rollout_logs + mask_abnormal_exit_traj: false + # Qwen3-235B-A22B uses the Hermes tool-call template. Using the wrong parser + # silently mis-parses tool calls and breaks training — match the parser to + # your model's chat template. 
+ tool_parser: hermes + + interaction: + action_timeout: 300 + max_turns: 300 + + env: + deployment: + type: modal + startup_timeout: 300 + runtime_timeout: 300 + deployment_timeout: 3600 + env_variables: + PIP_PROGRESS_BAR: "off" + PIP_CACHE_DIR: "~/.cache/pip" + PAGER: "cat" + MANPAGER: "cat" + LESS: "-R" + TQDM_DISABLE: "1" + GIT_PAGER: "cat" + + tools: + - name: str_replace_editor + - name: execute_bash + - name: submit + + reward: + eval_timeout: 300 diff --git a/examples/swe_agent_235b/runtime_env.yaml b/examples/swe_agent_235b/runtime_env.yaml new file mode 100644 index 00000000..eae62a7e --- /dev/null +++ b/examples/swe_agent_235b/runtime_env.yaml @@ -0,0 +1,51 @@ +# Ray runtime env for the SWE-bench fully-async RL training example. +# +# This is a TEMPLATE. Fill in the placeholders (<...>) before launching, or +# export the corresponding variables in your shell and drop them here. +# +# `working_dir` is uploaded to every Ray worker, so point it at your uni-agent +# checkout. PYTHONPATH must include both your verl checkout and uni-agent so the +# `verl.experimental.fully_async_policy` entrypoint and `uni_agent.agent_loop` +# resolve on the workers. + +working_dir: +excludes: ["/.git/", "/.venv/", "/__pycache__/"] + +pip: + - loguru + - pydantic + - pydantic_settings + - swebench + - modal + - swe-rex + - boto3 + - aiohttp + +env_vars: + PYTHONPATH: ":" + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + CUDA_DEVICE_MAX_CONNECTIONS: "1" + # Some MoE EP kernels JIT-compile at runtime and read $CUDA_HOME to locate + # nvcc. If your container leaves it unset, the compile fails with + # "/bin/nvcc not found"; set it explicitly. + CUDA_HOME: "/usr/local/cuda" + # IMPORTANT: do NOT set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. + # vLLM's CuMemAllocator (sleep mode / weight transfer) is incompatible with + # expandable segments and will assert at startup + # (vllm/device_allocator/cumem.py). See pytorch/pytorch#147851. 
+ # + # NCCL settings below are tuned for NVLink-rich multi-GPU nodes (e.g. GB200 + # NVL72). Adjust or remove for other fabrics. + NCCL_CUMEM_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + # Modal credentials for the swe-rex sandboxes used by the SWE-bench rollout. + # Create a token with `modal token new` (or `modal token set --profile=` + # for a team workspace) and paste the id/secret here, or leave these unset and + # rely on ~/.modal.toml on every Ray node. + MODAL_TOKEN_ID: "" + MODAL_TOKEN_SECRET: "" + # Weights & Biases. Prefer `wandb login` on the nodes; only set this if you + # must pass the key through the runtime env. + WANDB_API_KEY: "" diff --git a/examples/swe_agent_235b/train_qwen3_235b_swebench.sh b/examples/swe_agent_235b/train_qwen3_235b_swebench.sh new file mode 100755 index 00000000..6b6e6802 --- /dev/null +++ b/examples/swe_agent_235b/train_qwen3_235b_swebench.sh @@ -0,0 +1,306 @@ +#!/usr/bin/env bash +# Qwen3-235B-A22B-Instruct-2507 GRPO training on SWE-bench with uni-agent rollout. +# +# Fully-async RL: a disaggregated train/rollout topology where Megatron actors +# and vLLM rollout replicas run on separate nodes and communicate via verl's +# fully_async_policy. Rollout trajectories are generated by uni-agent's +# UniAgentLoop driving Modal swe-rex sandboxes (one sandbox per SWE-bench task). +# +# Reference topology (12 nodes x 4 GPU; tune NNODES_* / parallelism to your cluster): +# ACTOR (train, 8 nodes x 4 GPU = 32 GPU, DP=1): TP=2 CP=2 PP=8 EP=4 ETP=1 +# ROLLOUT (inference, 4 nodes x 4 GPU = 16 GPU, 4 vLLM replicas): INFER_TP=4 +# +# Model: Qwen3-235B-A22B-Instruct-2507. The Instruct-2507 variant has +# max_position_embeddings=262144 (256K native), so 128K context works without +# YaRN. (The base Qwen3-235B-A22B has only 40K native and needs YaRN wiring on +# both the vLLM and Megatron sides.) 
+# +# Stack: +# - data: SWE-bench parquet produced by examples/data_preprocess/swe_rebench.py +# (train) and swe_bench_verified.py (val) +# - agent: uni_agent.agent_loop.UniAgentLoop (Modal swe-rex sandboxes) +# - reward: test-suite outcome (uni_agent.reward.swe_rebench / swe_bench), vanilla GRPO loss +# - context: max_prompt 4K, max_response 128K +# +# vLLM rollout notes (vLLM 0.21.x): +# - performance_mode=interactivity favoured throughput in our N=32 concurrency sweep +# +# All paths/addresses/topology are overridable via the env vars below. + +set -xeuo pipefail +export CUDA_DEVICE_MAX_CONNECTIONS=1 +# vLLM 0.21 EP/CUTLASS init workaround. +export VLLM_USE_DEEP_GEMM=0 + +# ================= ray + experiment naming ================= +RAY_ADDRESS=${RAY_ADDRESS:-http://127.0.0.1:8265} + +# ================= data / model ================= +# DATA_ROOT holds the HF model cache (hf-models/hub/...) and the SWE-bench +# parquet files. Point these at your own storage. +DATA_ROOT=${DATA_ROOT:-/path/to/data-root} +EXAMPLE_DIR=${EXAMPLE_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)} +# Qwen3-235B-A22B-Instruct-2507 snapshot (first/only snapshot under the HF cache). 
+MODEL_PATH=${MODEL_PATH:-$(ls -d "${DATA_ROOT}"/hf-models/hub/models--Qwen--Qwen3-235B-A22B-Instruct-2507/snapshots/*/ 2>/dev/null | head -1)} +MODEL_PATH=${MODEL_PATH%/} +TRAIN_FILE=${TRAIN_FILE:-${DATA_ROOT}/data/swe_agent/swe_rebench_filtered.parquet} +TEST_FILE=${TEST_FILE:-${DATA_ROOT}/data/swe_agent/swe_bench_verified.parquet} +AGENT_CONFIG_PATH=${AGENT_CONFIG_PATH:-${EXAMPLE_DIR}/agent_config.yaml} +RUNTIME_ENV=${RUNTIME_ENV:-${EXAMPLE_DIR}/runtime_env.yaml} + +project_name=${PROJECT_NAME:-'Qwen3-235B-A22B-Instruct-2507-grpo'} +exp_name=${EXP_NAME:-"qwen3_235b_swebench_$(date +%Y%m%d_%H%M)"} +CKPTS_DIR=${CKPTS_DIR:-./ckpts/${project_name}/${exp_name}} +mkdir -p "${CKPTS_DIR}" + +# ================= algorithm ================= +adv_estimator=grpo +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.001 +kl_loss_type=low_var_kl +clip_ratio_low=0.2 +clip_ratio_high=0.28 +max_prompt_length=$((1024 * 4)) +# SWE-bench trajectories are long (mean response ~70K tokens, ~90 turns); a 32K +# cap truncates roughly half of them. Start at 128K. +max_response_length=$((1024 * 128)) +loss_mode="vanilla" +loss_agg_mode="token-mean" +temperature=1.0 +top_p=1.0 +val_top_p=0.95 + +# ================= parallelism ================= +INFER_TP=${INFER_TP:-4} +INFER_EP=${INFER_EP:-1} +# Leave headroom for the checkpoint-engine weight-sync buckets (~2 GiB/bucket). +INFER_MEM_UTILIZATION=${INFER_MEM_UTILIZATION:-0.9} +update_weights_bucket_megabytes=2048 + +ACTOR_TP=${ACTOR_TP:-2} +ACTOR_CP=${ACTOR_CP:-2} +# PP=8: 94 layers = 11 + 6x12 + 11 (first/last stage = 11; 94 is not divisible by 8, +# so the pipeline split is uneven). If you change PP, also update the hardcoded +# num_layers_in_{first,last}_pipeline_stage=11 overrides in the launch command. +ACTOR_PP=${ACTOR_PP:-8} +ACTOR_VPP=${ACTOR_VPP:-null} +ACTOR_EP=${ACTOR_EP:-4} +ACTOR_ETP=${ACTOR_ETP:-1} + +# Megatron-side offload. At this topology the per-rank weight footprint is small +# enough to keep on GPU; rely on verl's precision-aware optimizer CPU offload. 
+COMMON_OFFLOAD=${COMMON_OFFLOAD:-False} +PARAM_OFFLOAD=${PARAM_OFFLOAD:-$COMMON_OFFLOAD} +GRAD_OFFLOAD=${GRAD_OFFLOAD:-$COMMON_OFFLOAD} +OPTIMIZER_OFFLOAD=${OPTIMIZER_OFFLOAD:-$COMMON_OFFLOAD} +optimizer_cpu_offload=True +optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.} + +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / ACTOR_CP)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / ACTOR_CP)) +train_ppo_micro_batch_size_per_gpu=2 +infer_ppo_micro_batch_size_per_gpu=2 + +# ================= async policy ================= +rollout_name="vllm" + +NNODES_ROLLOUT=${NNODES_ROLLOUT:-4} +NNODES_TRAIN=${NNODES_TRAIN:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-4} + +train_batch_size=0 +gen_prompt_bsz=1 +n_resp_per_prompt=8 # GRPO group size +ppo_mini_batch_size=32 +total_rollout_steps=200000 +test_freq=${TEST_FREQ:-20} +save_freq=${SAVE_FREQ:-1} +staleness_threshold=1.0 +trigger_parameter_sync_step=4 +require_batches=1 +partial_rollout=True +val_before_train=${VAL_BEFORE_TRAIN:-False} + +# Rollout importance/rejection sampling correction. 
+rollout_is=null +rollout_rs=seq_mean_k1 +rollout_rs_threshold="0.999_1.001" + +USE_MBRIDGE=True +VANILLA_MBRIDGE=True +USE_DIST_CKPT=False + +# ================= Prometheus monitoring (optional) ================= +ENABLE_PROMETHEUS_MONITORING=${ENABLE_PROMETHEUS_MONITORING:-false} +PROMETHEUS_PORT=${PROMETHEUS_PORT:-9090} +PROMETHEUS_CONFIG_FILE=${PROMETHEUS_CONFIG_FILE:-/tmp/prometheus.yml} +PROMETHEUS_SERVED_MODEL_NAME=${PROMETHEUS_SERVED_MODEL_NAME:-qwen3_235b_a22b} + +prometheus_params=() +if [[ "$ENABLE_PROMETHEUS_MONITORING" == "true" ]]; then + prometheus_params=( + actor_rollout_ref.rollout.prometheus.enable=True + actor_rollout_ref.rollout.prometheus.port=${PROMETHEUS_PORT} + actor_rollout_ref.rollout.prometheus.file=${PROMETHEUS_CONFIG_FILE} + actor_rollout_ref.rollout.prometheus.served_model_name=${PROMETHEUS_SERVED_MODEL_NAME} + ) + echo "[train] Prometheus monitoring ENABLED" +else + prometheus_params=(actor_rollout_ref.rollout.prometheus.enable=False) + echo "[train] Prometheus monitoring DISABLED" +fi + +# ================= MTP params ================= +# Disabled — Qwen3-235B-A22B has no native MTP head. 
+mtp_params=( + actor_rollout_ref.model.mtp.enable=False + actor_rollout_ref.model.mtp.enable_train=False + actor_rollout_ref.model.mtp.enable_rollout=False +) + +CHECKPOINT_CONTENTS=['model','hf_model','extra'] + +ray job submit --no-wait --address=$RAY_ADDRESS --runtime-env $RUNTIME_ENV \ + -- python3 -m verl.experimental.fully_async_policy.fully_async_main \ + --config-path=config \ + --config-name='fully_async_ppo_megatron_trainer.yaml' \ + hydra.searchpath=[pkg://verl.trainer.config] \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size="${train_batch_size}" \ + data.return_raw_chat=True \ + data.gen_batch_size=${gen_prompt_bsz} \ + +data.apply_chat_template_kwargs.thinking=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.prompt_length=${max_prompt_length} \ + actor_rollout_ref.rollout.response_length=${max_response_length} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE} \ + actor_rollout_ref.actor.megatron.vanilla_mbridge=${VANILLA_MBRIDGE} \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.megatron.use_remove_padding=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.kl_loss_type=${kl_loss_type} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + 
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_decay_style='constant' \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=${optimizer_cpu_offload} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.megatron.param_offload=${PARAM_OFFLOAD} \ + actor_rollout_ref.actor.megatron.grad_offload=${GRAD_OFFLOAD} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${OPTIMIZER_OFFLOAD} \ + 
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.checkpoint.async_save=False \ + actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \ + actor_rollout_ref.rollout.gpu_memory_utilization=${INFER_MEM_UTILIZATION} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \ + actor_rollout_ref.rollout.expert_parallel_size=${INFER_EP} \ + actor_rollout_ref.rollout.max_model_len=$((max_prompt_length + max_response_length)) \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.performance_mode=interactivity \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.nccl_timeout=9600 \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${ACTOR_PP} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${ACTOR_EP} \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${ACTOR_CP} \ + actor_rollout_ref.ref.megatron.param_offload=True \ + 
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=11 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=11 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="alltoall" \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.01 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_z_loss_coeff=0.001 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_shared_expert_overlap=False \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \ + actor_rollout_ref.rollout.agent.num_workers=8 \ + 
actor_rollout_ref.rollout.agent.agent_loop_config_path=${AGENT_CONFIG_PATH} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=${update_weights_bucket_megabytes} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.save_freq=${save_freq} \ + trainer.max_actor_ckpt_to_keep=${MAX_CKPT_TO_KEEP:-2} \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + trainer.nnodes=${NNODES_TRAIN} \ + trainer.n_gpus_per_node=${NGPUS_PER_NODE} \ + rollout.nnodes=${NNODES_ROLLOUT} \ + rollout.n_gpus_per_node=${NGPUS_PER_NODE} \ + rollout.total_rollout_steps=${total_rollout_steps} \ + trainer.total_epochs=10 \ + trainer.test_freq=${test_freq} \ + trainer.val_before_train=${val_before_train} \ + async_training.staleness_threshold=${staleness_threshold} \ + async_training.trigger_parameter_sync_step=${trigger_parameter_sync_step} \ + async_training.require_batches=${require_batches} \ + async_training.partial_rollout=${partial_rollout} \ + "${mtp_params[@]}" \ + "${prometheus_params[@]}" \ + "$@"