Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions examples/agent_train/train_qwen3_moe_rc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name=${PROJECT_NAME:-'Uni-Agent-Qwen3-Coder-30B-megatron'}
exp_name=${EXP_NAME:-"$(date +%Y%m%d%H)_exp"}

RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-Coder-30B-A3B-Instruct"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/swe_agent/swe_rebench_v2_modal.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/swe_agent/swe_bench_verified.parquet"}
RUNTIME_ENV=${RUNTIME_ENV:-"${RAY_DATA_HOME}/data/swe_agent/runtime_env.yaml"}
# Must be launched from the repository root so Ray packages both `verl/` and `uni_agent/`.
AGENT_CONFIG_PATH=${AGENT_CONFIG_PATH:-"${RAY_DATA_HOME}/data/swe_agent/agent_config.yaml"}

rollout_mode=${ROLLOUT_MODE:-"async"}
rollout_name=${ROLLOUT_NAME:-"vllm"} # sglang or vllm

# Algorithm parameters
adv_estimator=${ADV_ESTIMATOR:-grpo}

use_kl_in_reward=${USE_KL_IN_REWARD:-False}
kl_coef=${KL_COEF:-0.0}
use_kl_loss=${USE_KL_LOSS:-False}
kl_loss_coef=${KL_LOSS_COEF:-0.0}

clip_ratio_low=${CLIP_RATIO_LOW:-4e-4}
clip_ratio_high=${CLIP_RATIO_HIGH:-4e-4}

# Response length parameters
max_prompt_length=${MAX_PROMPT_LENGTH:-$((1024 * 8))}
max_response_length=${MAX_RESPONSE_LENGTH:-$((1024 * 128))}
enable_overlong_buffer=${ENABLE_OVERLONG_BUFFER:-False}
overlong_buffer_len=${OVERLONG_BUFFER_LEN:-$((1024 * 4))} # unused
overlong_penalty_factor=${OVERLONG_PENALTY_FACTOR:-1.0}

loss_agg_mode=${LOSS_AGG_MODE:-"token-mean"}
loss_mode=${LOSS_MODE:-gspo}

# Algorithm
temperature=${TEMPERATURE:-1.0}
top_p=${TOP_P:-1.0}
top_k=${TOP_K:--1}
val_temperature=${VAL_TEMPERATURE:-1.0}
val_top_p=${VAL_TOP_P:-0.95}
val_top_k=${VAL_TOP_K:--1}

# Performance Related Parameter
use_dynamic_bsz=${USE_DYNAMIC_BSZ:-True}
offload=${OFFLOAD:-True}
gen_tp=${GEN_TP:-4}
train_tp=${TP:-4}
train_pp=${PP:-2}
train_cp=${CP:-4}
train_ep=${EP:-8}
train_etp=${ETP:-1}
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / train_cp))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / train_cp))

optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.0}

# install mbridge
# pip3 install git+https://github.com/ISEEKYAN/mbridge
USE_MBRIDGE=${USE_MBRIDGE:-True}
USE_DIST_CKPT=${USE_DIST_CKPT:-False}

# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-12}
NNODES_TRAIN=${NNODES_TRAIN:-4}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

train_prompt_bsz=0
n_resp_per_prompt=${N_RESP_PER_PROMPT:-16}
train_prompt_mini_bsz=${PPO_MINI_BATCH_SIZE:-8}
total_rollout_steps=${TOTAL_ROLLOUT_STEPS:-10000}
test_freq=${TEST_FREQ:-10}
staleness_threshold=${STALENESS_THRESHOLD:-1.0}
trigger_parameter_sync_step=${TRIGGER_PARAMETER_SYNC_STEP:-4}
require_batches=${REQUIRE_BATCHES:-1}
partial_rollout=${PARTIAL_ROLLOUT:-True}

# ============================================================================
# Decoupled PPO (bypass_mode=False) + Rollout Correction (Rollout IS)
# ============================================================================
bypass_mode=${BYPASS_MODE:-False} # False => decoupled PPO (recompute old_log_prob as proximal anchor)
rollout_is=${ROLLOUT_IS:-token} # token | sequence | null (IS aggregation level)
rollout_is_threshold=${ROLLOUT_IS_THRESHOLD:-2.0} # single float => TIS upper clamp; "lo_hi" string => IcePop
rollout_is_batch_normalize=${ROLLOUT_IS_BATCH_NORMALIZE:-False} # normalize IS weights to mean=1.0 within a batch
rollout_rs=${ROLLOUT_RS:-seq_mean_k1} # seq_mean_k1 | seq_mean_k3 | token_k1 | null
rollout_rs_threshold=${ROLLOUT_RS_THRESHOLD:-"0.999_1.001"} # k1: "lo_hi" ratio band; k3: single upper bound

# ============================================================================
# 30B MoE Router Replay (R3)
# ============================================================================
router_replay_mode=${ROUTER_REPLAY_MODE:-R3} # disabled | R2 | R3
enable_rollout_routing_replay=${ENABLE_ROLLOUT_ROUTING_REPLAY:-True} # required for R3 (rollout-side replay)

ray job submit --no-wait --runtime-env $RUNTIME_ENV \
-- python3 -m verl.experimental.fully_async_policy.fully_async_main \
--config-name='fully_async_ppo_megatron_trainer.yaml' \
hydra.searchpath=[pkg://verl.trainer.config] \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
+actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.model.use_fused_kernels=False \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_decay_style='constant' \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps} \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${train_etp} \
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="alltoall" \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
algorithm.rollout_correction.bypass_mode=${bypass_mode} \
algorithm.rollout_correction.rollout_is=${rollout_is} \
algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \
algorithm.rollout_correction.rollout_is_batch_normalize=${rollout_is_batch_normalize} \
algorithm.rollout_correction.rollout_rs=${rollout_rs} \
algorithm.rollout_correction.rollout_rs_threshold="${rollout_rs_threshold}" \
actor_rollout_ref.actor.router_replay.mode=${router_replay_mode} \
actor_rollout_ref.rollout.enable_rollout_routing_replay=${enable_rollout_routing_replay} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+actor_rollout_ref.actor.checkpoint.save_contents=['hf_model'] \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \
actor_rollout_ref.rollout.agent.num_workers=8 \
actor_rollout_ref.rollout.agent.agent_loop_config_path=${AGENT_CONFIG_PATH} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.max_model_len=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.nccl_timeout=9600 \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${train_etp} \
reward.reward_manager.name=dapo \
+reward.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward.reward_kwargs.overlong_buffer_cfg.log=False \
+reward.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=10 \
trainer.total_epochs=10 \
trainer.resume_mode=auto \
trainer.log_val_generations=10 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
trainer.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}"
Loading
Loading