
[trainer,rollout,algo] feat: (MOPD, 1/3) Multi-Teacher Model and Server Managers#5834

Open
JacobHelwig wants to merge 5 commits into verl-project:main from JacobHelwig:jhelwig/multiTeacherManager

Conversation

JacobHelwig (Collaborator) commented Apr 1, 2026

What does this PR do?

Adds classes for managing multiple sets of teacher models and servers for multi-teacher on-policy distillation (MOPD). Although the architecture is designed for multiple teachers, this PR supports only a single teacher.

Testing

The runs below demonstrate no regression in single-teacher OPD relative to main.

  • Green = current main
  • Blue = colocate mode, this PR
  • Red = standalone mode, this PR

Script

set -xeuo pipefail

############################ Quick Config ############################

ROLLOUT_NAME="vllm" # sglang or vllm

FAMILY="Qwen"
STUDENT_MODEL=Qwen2.5-0.5B
TEACHER_MODEL=Qwen2.5-3B-Instruct

USE_POLICY_GRADIENT=False
# DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="forward_kl_topk"
USE_FUSED_KERNELS=False

# USE_POLICY_GRADIENT=True
# DISTILLATION_LOSS_MODE="k1"
# USE_FUSED_KERNELS=False

DISTILLATION_LOSS_MAX_CLAMP=10.0
DISTILLATION_LOG_PROB_MIN_CLAMP=-10.0

PROJECT_NAME='verl_on_policy_distillation_example_gsm8k'

MAX_PROMPT=256
MAX_RESPONSE_LENGTH=512
MAX_NUM_TOKENS=$(( MAX_PROMPT + MAX_RESPONSE_LENGTH + 1 ))
TRAIN_PROMPT_BSZ=128
STUDENT_MICRO_BATCH_SIZE_PER_GPU=8
STUDENT_MAX_TOKEN_LEN_PER_GPU=$(( STUDENT_MICRO_BATCH_SIZE_PER_GPU * (MAX_PROMPT + MAX_RESPONSE_LENGTH) ))
USE_DYNAMIC_BSZ=True

# Standalone mode (overridden by the colocate block below; comment out
# whichever mode is not being tested)
MODE=STANDALONE
TEACHER_RESOURCE_POOL=True
STUDENT_WORLD_SIZE=4
TEACHER_WORLD_SIZE=4

# Colocate mode (active)
MODE=COLOCATE
TEACHER_RESOURCE_POOL=False
STUDENT_WORLD_SIZE=8
TEACHER_WORLD_SIZE=1

# export CUDA_VISIBLE_DEVICES=2,3
# MODE=COLOCATE
# TEACHER_RESOURCE_POOL=True
# STUDENT_WORLD_SIZE=1
# TEACHER_WORLD_SIZE=1

SP=1

EXP_NAME="MAIN/${MODE}/student-${STUDENT_MODEL}/teacher-${TEACHER_MODEL}/loss-${DISTILLATION_LOSS_MODE}/pg-${USE_POLICY_GRADIENT}"

ENFORCE_EAGER=True # true for faster debugging

############################ Paths ############################

gsm8k_train_path=$DATA_PATH/gsm8k/train.parquet
gsm8k_test_path=$DATA_PATH/gsm8k/test.parquet

TRAIN_FILES="['$gsm8k_train_path']"
TEST_FILES="['$gsm8k_test_path']"

############################ Parameter Groups ############################

DATA=(
    data.train_files="$TRAIN_FILES"
    data.val_files="$TEST_FILES"
    data.max_prompt_length=$MAX_PROMPT
    data.max_response_length=$MAX_RESPONSE_LENGTH
    data.train_batch_size=$TRAIN_PROMPT_BSZ
    data.filter_overlong_prompts=True
    data.truncation='error'
    data.shuffle=False
)

MODEL=(
    actor_rollout_ref.model.path="${FAMILY}/${STUDENT_MODEL}"
    actor_rollout_ref.model.enable_gradient_checkpointing=True
    actor_rollout_ref.model.use_remove_padding=True
    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS
    actor_rollout_ref.actor.use_torch_compile=True
    actor_rollout_ref.rollout.enforce_eager=$ENFORCE_EAGER
)

DISTILLATION=(
    distillation.enabled=True
    distillation.num_workers=8
    distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
    distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
    distillation.teacher_model.nnodes=1
    distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
    distillation.teacher_model.inference.tensor_model_parallel_size=1
    distillation.teacher_model.inference.name=$ROLLOUT_NAME
    distillation.teacher_model.inference.gpu_memory_utilization=0.3
    distillation.teacher_model.inference.enforce_eager=$ENFORCE_EAGER
    distillation.teacher_model.inference.max_model_len=$MAX_NUM_TOKENS
    distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
    distillation.teacher_model.inference.max_num_seqs=$MAX_NUM_TOKENS
    distillation.distillation_loss.loss_mode=$DISTILLATION_LOSS_MODE
    distillation.distillation_loss.topk=64
    distillation.distillation_loss.use_task_rewards=False
    distillation.distillation_loss.use_policy_gradient=$USE_POLICY_GRADIENT
    distillation.distillation_loss.loss_max_clamp=$DISTILLATION_LOSS_MAX_CLAMP
    distillation.distillation_loss.log_prob_min_clamp=$DISTILLATION_LOG_PROB_MIN_CLAMP
)

STUDENT=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_PROMPT_BSZ
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
    actor_rollout_ref.actor.use_dynamic_bsz=$USE_DYNAMIC_BSZ
    actor_rollout_ref.actor.fsdp_config.param_offload=True
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$SP
)

ROLLOUT=(
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=$USE_DYNAMIC_BSZ
    actor_rollout_ref.rollout.tensor_model_parallel_size=1
    actor_rollout_ref.rollout.name=$ROLLOUT_NAME
    actor_rollout_ref.rollout.gpu_memory_utilization=0.3
    actor_rollout_ref.rollout.calculate_log_probs=False
    actor_rollout_ref.rollout.max_model_len=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.max_num_seqs=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.n=1
)

ALGORITHM=(
    algorithm.adv_estimator=grpo
    algorithm.use_kl_in_reward=False
)

TRAINER=(
    trainer.logger='["console","wandb"]'
    trainer.project_name=$PROJECT_NAME
    trainer.experiment_name=$EXP_NAME
    trainer.n_gpus_per_node=$STUDENT_WORLD_SIZE
    trainer.nnodes=1
    trainer.save_freq=200
    trainer.test_freq=5
    trainer.total_epochs=15
    trainer.val_before_train=False
    trainer.use_legacy_worker_impl=disable
    trainer.resume_mode=disable
    trainer.log_val_generations=5
)



############################ Launch ############################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${DISTILLATION[@]}" \
    "${ROLLOUT[@]}" \
    "${STUDENT[@]}" \
    "${TRAINER[@]}" \
    "$@"

GSM8k eval acc

(figure: eval accuracy curves)

GSM8k train acc

(figure: train accuracy curves)

GSM8k distillation loss

(figure: distillation loss curves)

Design & Code Changes

  1. MultiTeacherModelManager: new class that manages multiple TeacherModelManager instances for MOPD.
  2. AsyncTeacherLLMServerManager: adds an attribute server_manager: AsyncLLMServerManager. Previously, AsyncTeacherLLMServerManager inherited from AsyncLLMServerManager; for MOPD, it will instead manage one AsyncLLMServerManager per teacher.
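A minimal sketch of the composition described in items 1 and 2. Constructor signatures and the `servers` attribute are illustrative stand-ins, not verl's actual API; only the class names and the inheritance-to-composition change come from this PR.

```python
class AsyncLLMServerManager:
    """Stand-in for verl's per-server manager (illustrative only)."""

    def __init__(self, model_path: str):
        self.model_path = model_path


class AsyncTeacherLLMServerManager:
    """After this PR: wraps AsyncLLMServerManager by composition instead of
    inheritance, so a later PR can hold one server manager per teacher."""

    def __init__(self, model_path: str):
        self.server_manager = AsyncLLMServerManager(model_path)


class TeacherModelManager:
    """Stand-in for the existing single-teacher manager."""

    def __init__(self, model_path: str):
        self.servers = AsyncTeacherLLMServerManager(model_path)


class MultiTeacherModelManager:
    """Fans out to one TeacherModelManager per teacher; this PR only
    exercises the single-teacher case."""

    def __init__(self, teacher_paths: list[str]):
        self.teachers = [TeacherModelManager(p) for p in teacher_paths]


mgr = MultiTeacherModelManager(["Qwen/Qwen2.5-3B-Instruct"])
print(len(mgr.teachers))  # 1
```

Composition over inheritance is what lets the follow-up PRs scale from one teacher to N without changing the server-manager interface.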

@JacobHelwig JacobHelwig changed the title [MOPD, 1/n][trainer,rollout,algo] feat: Multi-Teacher Model Manager [trainer,rollout,algo] feat (MOPD, 1/n): Multi-Teacher Model Manager Apr 1, 2026
@JacobHelwig JacobHelwig changed the title [trainer,rollout,algo] feat (MOPD, 1/n): Multi-Teacher Model Manager [trainer,rollout,algo] feat: (MOPD, 1/n) Multi-Teacher Model Manager Apr 1, 2026
gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a MultiTeacherModelManager to support multiple teacher models during distillation, refactoring the existing TeacherModelManager and AsyncTeacherLLMServerManager to accommodate this architecture. Key changes include moving configuration and tokenizer management into the server manager and introducing helper functions for async task execution. Review feedback identifies a potential issue with mixing auto_await and asyncio.run in compute_logprobs, incorrect type hints for coroutines in _run_all, and performance risks associated with redundant tokenizer loading in the AsyncTeacherLLMServerManager constructor.

Comment on lines 190 to 195

```diff
 def compute_logprobs(self, data):
     self.wake_up()
     try:
-        return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data))
+        return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
     finally:
         self.sleep()
```
high

The compute_logprobs method is synchronous, but it calls self.wake_up() and self.sleep(), which are decorated with @auto_await. When called synchronously, @auto_await typically runs the coroutine to completion using the current event loop or by creating a new one. However, _run_single explicitly uses asyncio.run(), which creates a new event loop and will fail if an event loop is already running in the same thread (which might happen if auto_await left a loop active).

More importantly, mixing auto_await for some calls and asyncio.run (via _run_single) for others in the same synchronous method is fragile and inefficient. It is better to wrap the entire sequence of operations into a single async method and run that once.

Suggested change

```python
def compute_logprobs(self, data):
    async def _compute():
        await self.wake_up()
        try:
            return await self.server_manager.compute_teacher_logprobs_batch(data)
        finally:
            await self.sleep()

    return _run_single(_compute())
```
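The failure mode the reviewer describes can be reproduced in isolation. This toy snippet (unrelated to verl's code) shows that `asyncio.run()` works from plain synchronous code but refuses to start while an event loop is already running in the same thread:

```python
import asyncio


async def compute():
    return 42


def sync_bridge():
    # asyncio.run() starts a brand-new event loop; Python refuses to do
    # this when the calling thread already has a loop running.
    return asyncio.run(compute())


async def caller_inside_loop():
    try:
        sync_bridge()
        return "ok"
    except RuntimeError:
        return "RuntimeError"


# From plain sync code there is no running loop, so the bridge works:
print(sync_bridge())                      # 42
# From inside a running loop it raises:
print(asyncio.run(caller_inside_loop()))  # RuntimeError
```

This is why the suggestion wraps the whole wake-up/compute/sleep sequence in one coroutine and crosses the sync/async boundary exactly once.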

Comment on lines +32 to +33

```python
async def _run_all(tasks: list[asyncio.Task]):
    await asyncio.gather(*tasks)
```
high

The type hint for tasks in _run_all is list[asyncio.Task], but the function is called with a list of coroutines (e.g., in TeacherModelManager.wake_up). While asyncio.gather accepts coroutines, the type hint is technically incorrect and might be misleading for static analysis tools.

Suggested change

```python
async def _run_all(tasks: list):
    await asyncio.gather(*tasks)
```
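An alternative to the bot's plain `list` hint is `list[Awaitable]`, which matches what `asyncio.gather` actually accepts (bare coroutines as well as tasks). A self-contained sketch, not verl's code:

```python
import asyncio
from collections.abc import Awaitable


async def _run_all(tasks: list[Awaitable]) -> list:
    # asyncio.gather accepts any awaitables (bare coroutines included)
    # and returns their results in input order.
    return await asyncio.gather(*tasks)


async def double(x: int) -> int:
    return x * 2


results = asyncio.run(_run_all([double(1), double(2), double(3)]))
print(results)  # [2, 4, 6]
```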

Comment on lines +110 to +114

```python
model_config = HFModelConfig(path=teacher_model_config.model_path)
text_tokenizer = model_config.tokenizer
if model_config.tokenizer is None:
    raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}")
self.pad_token_id = text_tokenizer.pad_token_id
```
high

Initializing HFModelConfig and accessing model_config.tokenizer inside __init__ may trigger redundant and expensive tokenizer loading every time an AsyncTeacherLLMServerManager is instantiated. Since this manager is initialized in the AgentLoopWorker, which can be numerous, this could lead to significant overhead and memory pressure. It is recommended to pass the pad_token_id directly or ensure the tokenizer is cached.
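One way to realize the reviewer's caching suggestion. This is a sketch: `load_pad_token_id` and its stub body are hypothetical, standing in for the real `HFModelConfig`/tokenizer load so the idea is runnable in isolation.

```python
from functools import lru_cache

loads: list[str] = []  # records how many expensive loads actually happen


@lru_cache(maxsize=None)
def load_pad_token_id(model_path: str) -> int:
    """Hypothetical helper: in real code this would build HFModelConfig and
    load the HF tokenizer; a cheap stub stands in here."""
    loads.append(model_path)
    return 0


# e.g. one AsyncTeacherLLMServerManager per AgentLoopWorker:
pad_ids = [load_pad_token_id("Qwen/Qwen2.5-3B-Instruct") for _ in range(8)]

print(len(loads))  # 1  (tokenizer "loaded" once, reused 7 times)
```

Passing the already-resolved `pad_token_id` into the constructor, as the reviewer also suggests, avoids even this cached load inside each worker.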

@JacobHelwig JacobHelwig force-pushed the jhelwig/multiTeacherManager branch from 12f4b5b to 15e2efe Compare April 1, 2026 23:34
@JacobHelwig JacobHelwig changed the title [trainer,rollout,algo] feat: (MOPD, 1/n) Multi-Teacher Model Manager [trainer,rollout,algo] feat: (MOPD, 1/3) Multi-Teacher Model Manager Apr 2, 2026
@JacobHelwig JacobHelwig changed the title [trainer,rollout,algo] feat: (MOPD, 1/3) Multi-Teacher Model Manager [trainer,rollout,algo] feat: (MOPD, 1/3) Multi-Teacher Model and Server Managers Apr 2, 2026
@wuxibin89 wuxibin89 mentioned this pull request Apr 7, 2026