From cdcd6a3c3d250b59495c8247aeeae93db15745b3 Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Fri, 27 Mar 2026 01:01:19 -0500 Subject: [PATCH 1/8] Multi-teacher cfg --- .../run_qwen3_vl_geo3k.sh | 7 +- .../run_qwen_gsm8k.sh | 7 +- .../run_qwen_gsmk8_megatron.sh | 7 +- verl/experimental/agent_loop/agent_loop.py | 6 +- .../teacher_loop/teacher_manager.py | 11 +- .../teacher_loop/teacher_model.py | 10 +- .../config/distillation/distillation.yaml | 26 ++- verl/trainer/main_ppo.py | 22 ++- verl/trainer/ppo/ray_trainer.py | 2 +- verl/workers/config/distillation.py | 169 +++++++++++------- 10 files changed, 156 insertions(+), 111 deletions(-) diff --git a/examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh b/examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh index 8fbfa0f6de3..2be487d8f80 100644 --- a/examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh +++ b/examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh @@ -75,10 +75,9 @@ MODEL=( DISTILLATION=( distillation.enabled=True - distillation.num_workers=8 - distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL - distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE - distillation.teacher_model.nnodes=1 + distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL + distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE + distillation.nnodes=1 distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}" distillation.teacher_model.inference.tensor_model_parallel_size=1 distillation.teacher_model.inference.name=$ROLLOUT_NAME diff --git a/examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh b/examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh index 5f0901052ec..6a2f9365409 100644 --- a/examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh +++ b/examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh @@ -74,10 +74,9 @@ MODEL=( DISTILLATION=( distillation.enabled=True - distillation.num_workers=8 - distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL - distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE - distillation.teacher_model.nnodes=1 + distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL + distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE + distillation.nnodes=1 distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}" distillation.teacher_model.inference.tensor_model_parallel_size=1 distillation.teacher_model.inference.name=$ROLLOUT_NAME diff --git a/examples/on_policy_distillation_trainer/run_qwen_gsmk8_megatron.sh b/examples/on_policy_distillation_trainer/run_qwen_gsmk8_megatron.sh index 40db99c43ef..9032c402225 100644 --- a/examples/on_policy_distillation_trainer/run_qwen_gsmk8_megatron.sh +++ b/examples/on_policy_distillation_trainer/run_qwen_gsmk8_megatron.sh @@ -83,10 +83,9 @@ MODEL=( DISTILLATION=( distillation.enabled=True - distillation.num_workers=8 - distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL - distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE - distillation.teacher_model.nnodes=1 + distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL + distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE + distillation.nnodes=1 distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}" distillation.teacher_model.inference.tensor_model_parallel_size=1 distillation.teacher_model.inference.name=$ROLLOUT_NAME diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 0d01ec610d9..fbc26ad3846 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -433,7 +433,7 @@ def __init__( if self.distillation_enabled: self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.distillation_config) self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss - self.stream_teacher_with_rollout = self.distillation_config.teacher_model.enable_resource_pool + self.stream_teacher_with_rollout = self.distillation_config.enable_resource_pool if self.stream_teacher_with_rollout: if teacher_servers is None: @@ -1012,9 +1012,7 @@ def __init__( self.teacher_model_manager = teacher_model_manager self.distillation_enabled = is_distillation_enabled(self.config.get("distillation", None)) - self.stream_teacher_with_rollout = ( - self.distillation_enabled and self.config.distillation.teacher_model.enable_resource_pool - ) + self.stream_teacher_with_rollout = self.distillation_enabled and self.config.distillation.enable_resource_pool assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode" diff --git a/verl/experimental/teacher_loop/teacher_manager.py b/verl/experimental/teacher_loop/teacher_manager.py index 28a89262e96..73ceff79971 100644 --- a/verl/experimental/teacher_loop/teacher_manager.py +++ b/verl/experimental/teacher_loop/teacher_manager.py @@ -25,21 +25,21 @@ from verl.protocol import DataProto from verl.utils.config import omega_conf_to_dataclass from verl.utils.tokenizer import normalize_token_ids -from verl.workers.config import DistillationConfig, DistillationLossConfig +from verl.workers.config import DistillationConfig, DistillationLossConfig, DistillationTeacherModelConfig def _get_teacher_sampling_params( - distillation_config: DistillationConfig, + teacher_model_config: DistillationTeacherModelConfig, distillation_loss_config: DistillationLossConfig, ) -> dict[str, Any]: """Get sampling parameters for teacher model when computing log probabilities for distillation.""" - if distillation_config.teacher_model.inference.temperature != 1.0: + if teacher_model_config.inference.temperature != 1.0: raise NotImplementedError("vLLM does not support temperature for prompt_logprobs.") num_logprobs = distillation_loss_config.topk if distillation_loss_config.loss_settings.use_topk else 0 return { "max_tokens": 1, - "temperature": distillation_config.teacher_model.inference.temperature, + "temperature": teacher_model_config.inference.temperature, "prompt_logprobs": num_logprobs, } @@ -102,6 +102,7 @@ def __init__( else: self.distillation_config: DistillationConfig = omega_conf_to_dataclass(distillation_config) self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss + self.teacher_model_config = self.distillation_config.get_single_teacher_model() self.pad_token_id = pad_token_id async def compute_teacher_logprobs_single( @@ -114,7 +115,7 @@ async def compute_teacher_logprobs_single( teacher_output = await self.generate( request_id=uuid4().hex, prompt_ids=sequence_ids, - sampling_params=_get_teacher_sampling_params(self.distillation_config, self.distillation_loss_config), + sampling_params=_get_teacher_sampling_params(self.teacher_model_config, self.distillation_loss_config), image_data=multi_modal_data.get("images"), video_data=multi_modal_data.get("videos"), ) diff --git a/verl/experimental/teacher_loop/teacher_model.py b/verl/experimental/teacher_loop/teacher_model.py index 966efa969ab..00dedf9dc7d 100644 --- a/verl/experimental/teacher_loop/teacher_model.py +++ b/verl/experimental/teacher_loop/teacher_model.py @@ -54,7 +54,7 @@ def __init__( self.sleep() def _initialize_llm_servers(self): - teacher_model_config: DistillationTeacherModelConfig = self.config.teacher_model + teacher_model_config: DistillationTeacherModelConfig = self.config.get_single_teacher_model() teacher_world_size = ( teacher_model_config.inference.tensor_model_parallel_size * teacher_model_config.inference.data_parallel_size @@ -63,8 +63,12 @@ def _initialize_llm_servers(self): world_size = ( self.resource_pool.world_size if self.resource_pool # colocate mode - else teacher_model_config.n_gpus_per_node * teacher_model_config.nnodes # standalone mode + else self.config.n_gpus_per_node * self.config.nnodes # standalone mode ) + if world_size % teacher_world_size != 0: + raise ValueError( + f"Teacher world size {teacher_world_size} must divide allocated resource pool size {world_size}." + ) num_replicas = world_size // teacher_world_size rollout_replica_class = get_rollout_replica_class(teacher_model_config.inference.name) @@ -80,7 +84,7 @@ def _initialize_llm_servers(self): replica_rank=replica_rank, config=rollout_config, model_config=model_config, - gpus_per_node=teacher_model_config.n_gpus_per_node, + gpus_per_node=self.config.n_gpus_per_node, is_teacher_model=True, ) for replica_rank in range(num_replicas) diff --git a/verl/trainer/config/distillation/distillation.yaml b/verl/trainer/config/distillation/distillation.yaml index 026a85a779b..173221c549b 100644 --- a/verl/trainer/config/distillation/distillation.yaml +++ b/verl/trainer/config/distillation/distillation.yaml @@ -13,9 +13,6 @@ defaults: # Whether to enable distillation. enabled: false -# we launch num_workers teacher managers to parallelize the teacher logprob computation -num_workers: 8 - # distillation loss config distillation_loss: @@ -58,21 +55,23 @@ distillation_loss: clip_ratio_high: 0.2 -# teacher model config -teacher_model: - _target_: verl.workers.config.DistillationTeacherModelConfig +# Whether to enable a separate resource pool for distillation teacher model(s) +enable_resource_pool: false - # Whether to enable separate resource pool for teacher model(s) - enable_resource_pool: false +# Number of GPUs per node in the distillation teacher resource pool +n_gpus_per_node: 8 - # Number of GPUs per node to use for distillation teacher model(s) - n_gpus_per_node: 8 +# Number of nodes in the distillation teacher resource pool +nnodes: 0 - # Number of nodes to use for distillation teacher model(s) - nnodes: 0 +# multi-teacher configs +teacher_models: {} +# single-teacher config +teacher_model: + _target_: verl.workers.config.DistillationTeacherModelConfig + task: null model_path: null - inference: _target_: verl.workers.config.RolloutConfig name: ${oc.select:actor_rollout_ref.rollout.name} @@ -94,7 +93,6 @@ teacher_model: enable_prefix_caching: true disable_log_stats: true skip_tokenizer_init: false - prompt_length: ${oc.select:actor_rollout_ref.rollout.prompt_length} response_length: ${oc.select:actor_rollout_ref.rollout.response_length} temperature: ${oc.select:actor_rollout_ref.rollout.temperature} \ No newline at end of file diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index da3267a308c..d7440d290ee 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -243,19 +243,17 @@ def init_resource_pool_mgr(self, config): distillation_config = config.get("distillation") if is_distillation_enabled(distillation_config): - if distillation_config.teacher_model.enable_resource_pool: - if distillation_config.teacher_model.n_gpus_per_node <= 0: - raise ValueError("config.distillation.teacher_model.n_gpus_per_node must be greater than 0") - if distillation_config.teacher_model.nnodes <= 0: - raise ValueError("config.distillation.teacher_model.nnodes must be greater than 0") - - teacher_pool = [ - distillation_config.teacher_model.n_gpus_per_node - ] * distillation_config.teacher_model.nnodes + if distillation_config.enable_resource_pool: + if distillation_config.n_gpus_per_node <= 0: + raise ValueError("config.distillation.n_gpus_per_node must be greater than 0") + if distillation_config.nnodes <= 0: + raise ValueError("config.distillation.nnodes must be greater than 0") + + teacher_pool = [distillation_config.n_gpus_per_node] * distillation_config.nnodes resource_pool_spec["teacher_pool"] = teacher_pool else: - distillation_config.teacher_model.nnodes = config.trainer.nnodes - distillation_config.teacher_model.n_gpus_per_node = config.trainer.n_gpus_per_node + distillation_config.nnodes = config.trainer.nnodes + distillation_config.n_gpus_per_node = config.trainer.n_gpus_per_node from verl.trainer.ppo.ray_trainer import ResourcePoolManager @@ -281,7 +279,7 @@ def add_teacher_model_resource_pool(self, config): if is_distillation_enabled(config.get("distillation")): # we do not use teacher model workers, so we only register teacher model in resource pool # without registering a teacher model worker in role-worker mapping - if config.distillation.teacher_model.enable_resource_pool: + if config.distillation.enable_resource_pool: self.mapping[Role.TeacherModel] = "teacher_pool" else: self.mapping[Role.TeacherModel] = "global_pool" diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index d1f138aa69a..37899e229d8 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -511,7 +511,7 @@ def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict return batch_reward def _should_compute_teacher_colocate(self, batch: DataProto) -> bool: - return self.use_teacher_policy and not self.distillation_config.teacher_model.enable_resource_pool + return self.use_teacher_policy and not self.distillation_config.enable_resource_pool def _compute_teacher_colocate(self, batch: DataProto) -> DataProto: """Compute teacher logprobs after rollout when teacher and student are colocated.""" diff --git a/verl/workers/config/distillation.py b/verl/workers/config/distillation.py index c78cf4aad76..d3bdb831946 100644 --- a/verl/workers/config/distillation.py +++ b/verl/workers/config/distillation.py @@ -18,6 +18,7 @@ from typing import Optional from verl.base_config import BaseConfig +from verl.utils.config import omega_conf_to_dataclass from .rollout import RolloutConfig @@ -115,12 +116,8 @@ def __post_init__(self): class DistillationTeacherModelConfig(BaseConfig): """Configuration for on-policy distillation teacher. - enable_resource_pool (bool): - Whether to enable separate resource pool for teacher model(s). - n_gpus_per_node (int): - Number of GPUs per node to use for distillation teacher model(s). - nnodes (int): - Number of nodes to use for distillation teacher model(s). + task (str, optional): + Task identifier to route examples to the teacher model in future multi-teacher support. model_path (str, optional): Model path for the teacher model. Can be a local path or a Hugging Face model inference (RolloutConfig): @@ -129,12 +126,64 @@ class DistillationTeacherModelConfig(BaseConfig): _mutable_fields = BaseConfig._mutable_fields - enable_resource_pool: bool = False - n_gpus_per_node: int = 0 - nnodes: int = 0 + task: Optional[str] = None model_path: Optional[str] = None inference: RolloutConfig = field(default_factory=RolloutConfig) + def is_configured(self) -> bool: + if self.task is not None and self.model_path is None: + raise ValueError("DistillationTeacherModelConfig is misconfigured with task but no model_path.") + return self.model_path is not None + + def validate_and_prepare_for_distillation(self, use_topk: bool, topk: Optional[int]) -> None: + # Prompt + Response from student are fed into teacher as context + max_model_len = self.inference.max_model_len + max_num_batched_tokens = self.inference.max_num_batched_tokens + student_prompt_length = self.inference.prompt_length + student_response_length = self.inference.response_length + required_context_len = student_prompt_length + student_response_length + 1 + if max_model_len is not None and required_context_len > max_model_len: + raise ValueError( + "Distillation teacher inference requires room for the student prompt, the full student " + f"response, and one generated token, but got {student_prompt_length=}, " + f"{student_response_length=}, {required_context_len=}, {max_model_len=}." + ) + if max_num_batched_tokens is not None and required_context_len > max_num_batched_tokens: + raise ValueError( + "Distillation teacher inference requires room for the student prompt, the full student " + f"response, and one generated token within the engine batching budget, but got " + f"{student_prompt_length=}, {student_response_length=}, {required_context_len=}, " + f"{max_num_batched_tokens=}." + ) + + self.inference.prompt_length = self.inference.prompt_length + self.inference.response_length + self.inference.response_length = 1 + self._validate_topk_logprobs(use_topk=use_topk, topk=topk) + + def _validate_topk_logprobs(self, use_topk: bool, topk: Optional[int]) -> None: + if not use_topk or topk is None: + return + + engine_name = self.inference.name + engine_kwargs = self.inference.engine_kwargs + match engine_name: + case "vllm": + vllm_engine_kwargs = dict(engine_kwargs.get("vllm", {})) + max_logprobs = vllm_engine_kwargs.get("max_logprobs") + if max_logprobs is None: + vllm_engine_kwargs["max_logprobs"] = topk + max_logprobs = topk + if max_logprobs < topk: + raise ValueError( + f"VLLM max_logprobs ({max_logprobs}) must be >= distillation_loss topk " + f"({topk}) to enable distillation loss computation." + ) + engine_kwargs["vllm"] = vllm_engine_kwargs + case _: + raise NotImplementedError( + f"DistillationTeacherModelConfig does not support inference engine {engine_name}" + ) + @dataclass class DistillationConfig(BaseConfig): @@ -142,67 +191,67 @@ class DistillationConfig(BaseConfig): enabled (bool): Whether on-policy distillation is enabled. - num_workers (int): - Number of teacher model replicas. + enable_resource_pool (bool): + Whether to enable a separate resource pool for distillation teacher model(s). + n_gpus_per_node (int): + Number of GPUs per node in the teacher resource pool. + nnodes (int): + Number of nodes in the teacher resource pool. teacher_model (TeacherModelConfig): - Configuration for the teacher model used for distillation. + Configuration for the single teacher model used for distillation. + teacher_models (dict[str, TeacherModelConfig]): + Configurations for teacher models used for multi-teacher distillation. distillation_loss (DistillationLossConfig): Configuration for distillation loss settings. """ - _mutable_fields = BaseConfig._mutable_fields + _mutable_fields = BaseConfig._mutable_fields | {"teacher_models"} enabled: bool = False - num_workers: int = 8 + enable_resource_pool: bool = False + n_gpus_per_node: int = 0 + nnodes: int = 0 teacher_model: DistillationTeacherModelConfig = field(default_factory=DistillationTeacherModelConfig) + teacher_models: dict[str, DistillationTeacherModelConfig] = field(default_factory=dict) distillation_loss: DistillationLossConfig = field(default_factory=DistillationLossConfig) def __post_init__(self): - # Prompt + Response from student are fed into teacher as context - max_model_len = self.teacher_model.inference.max_model_len - max_num_batched_tokens = self.teacher_model.inference.max_num_batched_tokens - student_prompt_length = self.teacher_model.inference.prompt_length - student_response_length = self.teacher_model.inference.response_length - if self.enabled: - required_context_len = student_prompt_length + student_response_length + 1 - if max_model_len is not None and required_context_len > max_model_len: - raise ValueError( - "Distillation teacher inference requires room for the student prompt, the full student " - f"response, and one generated token, but got {student_prompt_length=}, " - f"{student_response_length=}, {required_context_len=}, {max_model_len=}." - ) - if max_num_batched_tokens is not None and required_context_len > max_num_batched_tokens: - raise ValueError( - "Distillation teacher inference requires room for the student prompt, the full student " - f"response, and one generated token within the engine batching budget, but got " - f"{student_prompt_length=}, {student_response_length=}, {required_context_len=}, " - f"{max_num_batched_tokens=}." - ) + if not self.enabled: + return - self.teacher_model.inference.prompt_length = ( - self.teacher_model.inference.prompt_length + self.teacher_model.inference.response_length - ) - self.teacher_model.inference.response_length = 1 + self.teacher_models = self._resolve_teacher_models() + if len(self.teacher_models) != 1: + raise NotImplementedError("`teacher_models` are not supported yet in the runtime path.") + for teacher_model in self.teacher_models.values(): + teacher_model.validate_and_prepare_for_distillation( + use_topk=self.distillation_loss.loss_settings.use_topk, + topk=self.distillation_loss.topk, + ) - # Ensure max log probs is aligned with top-k - engine_name = self.teacher_model.inference.name - engine_kwargs = self.teacher_model.inference.engine_kwargs - if not self.distillation_loss.loss_settings.use_topk or self.distillation_loss.topk is None or not self.enabled: - return - match engine_name: - case "vllm": - vllm_engine_kwargs = dict(engine_kwargs.get("vllm", {})) - max_logprobs = vllm_engine_kwargs.get("max_logprobs") - if max_logprobs is None: - vllm_engine_kwargs["max_logprobs"] = self.distillation_loss.topk - max_logprobs = self.distillation_loss.topk - if max_logprobs < self.distillation_loss.topk: - raise ValueError( - f"VLLM max_logprobs ({max_logprobs}) must be >= distillation_loss topk " - f"({self.distillation_loss.topk}) to enable distillation loss computation." - ) - engine_kwargs["vllm"] = vllm_engine_kwargs - case _: - raise NotImplementedError( - f"DistillationTeacherModelConfig does not support inference engine {engine_name}" - ) + def get_single_teacher_model(self) -> DistillationTeacherModelConfig: + if len(self.teacher_models) != 1: + raise ValueError( + f"Expected exactly one active distillation teacher config, but got {len(self.teacher_models)}." + ) + return next(iter(self.teacher_models.values())) + + def _resolve_teacher_models(self) -> dict[str, DistillationTeacherModelConfig]: + if self.teacher_model.is_configured() and self.teacher_models: + raise ValueError("Specify either distillation.teacher_model or distillation.teacher_models, not both.") + + if self.teacher_models: + teacher_models = {} + for model_name, teacher_model in self.teacher_models.items(): + teacher_model = omega_conf_to_dataclass(teacher_model, dataclass_type=DistillationTeacherModelConfig) + if teacher_model.task is None: + raise ValueError(f"distillation.teacher_models.{model_name}.task must be non-null.") + teacher_models[model_name] = teacher_model + return teacher_models + + if not self.teacher_model.is_configured(): + raise ValueError( + "Distillation is enabled but no teacher model is configured. " + "Please configure distillation.teacher_model or distillation.teacher_models." + ) + + return {"teacher": self.teacher_model} From e85de2661a2a86ed9203d8323d463b055f4e1ead Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Fri, 27 Mar 2026 01:01:56 -0500 Subject: [PATCH 2/8] Fix megatron name --- .../{run_qwen_gsmk8_megatron.sh => run_qwen_gsm8k_megatron.sh} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/on_policy_distillation_trainer/{run_qwen_gsmk8_megatron.sh => run_qwen_gsm8k_megatron.sh} (100%) diff --git a/examples/on_policy_distillation_trainer/run_qwen_gsmk8_megatron.sh b/examples/on_policy_distillation_trainer/run_qwen_gsm8k_megatron.sh similarity index 100% rename from examples/on_policy_distillation_trainer/run_qwen_gsmk8_megatron.sh rename to examples/on_policy_distillation_trainer/run_qwen_gsm8k_megatron.sh From 4ca6da7e4359df9e807705b099a250e1cc8c0414 Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Fri, 27 Mar 2026 01:06:38 -0500 Subject: [PATCH 3/8] Generate cfgs --- verl/trainer/config/_generated_ppo_megatron_trainer.yaml | 9 +++++---- .../config/_generated_ppo_torchtitan_trainer.yaml | 9 +++++---- verl/trainer/config/_generated_ppo_trainer.yaml | 9 +++++---- verl/trainer/config/_generated_ppo_veomni_trainer.yaml | 9 +++++---- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index f48133d5a2e..a34aac53a00 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -735,7 +735,6 @@ algorithm: distillation: _target_: verl.workers.config.DistillationConfig enabled: false - num_workers: 8 distillation_loss: _target_: verl.workers.config.DistillationLossConfig loss_mode: k3 @@ -749,11 +748,13 @@ distillation: clip_ratio: 0.2 clip_ratio_low: 0.2 clip_ratio_high: 0.2 + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + teacher_models: {} teacher_model: _target_: verl.workers.config.DistillationTeacherModelConfig - enable_resource_pool: false - n_gpus_per_node: 8 - nnodes: 0 + task: null model_path: null inference: _target_: verl.workers.config.RolloutConfig diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml index fe1a87fb75b..2706cc9796e 100644 --- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml @@ -628,7 +628,6 @@ algorithm: distillation: _target_: verl.workers.config.DistillationConfig enabled: false - num_workers: 8 distillation_loss: _target_: verl.workers.config.DistillationLossConfig loss_mode: k3 @@ -642,11 +641,13 @@ distillation: clip_ratio: 0.2 clip_ratio_low: 0.2 clip_ratio_high: 0.2 + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + teacher_models: {} teacher_model: _target_: verl.workers.config.DistillationTeacherModelConfig - enable_resource_pool: false - n_gpus_per_node: 8 - nnodes: 0 + task: null model_path: null inference: _target_: verl.workers.config.RolloutConfig diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 6ff7450b19e..18d695596a2 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -660,7 +660,6 @@ algorithm: distillation: _target_: verl.workers.config.DistillationConfig enabled: false - num_workers: 8 distillation_loss: _target_: verl.workers.config.DistillationLossConfig loss_mode: k3 @@ -674,11 +673,13 @@ distillation: clip_ratio: 0.2 clip_ratio_low: 0.2 clip_ratio_high: 0.2 + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + teacher_models: {} teacher_model: _target_: verl.workers.config.DistillationTeacherModelConfig - enable_resource_pool: false - n_gpus_per_node: 8 - nnodes: 0 + task: null model_path: null inference: _target_: verl.workers.config.RolloutConfig diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index 8537a537a60..b7b23172a09 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -605,7 +605,6 @@ algorithm: distillation: _target_: verl.workers.config.DistillationConfig enabled: false - num_workers: 8 distillation_loss: _target_: verl.workers.config.DistillationLossConfig loss_mode: k3 @@ -619,11 +618,13 @@ distillation: clip_ratio: 0.2 clip_ratio_low: 0.2 clip_ratio_high: 0.2 + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + teacher_models: {} teacher_model: _target_: verl.workers.config.DistillationTeacherModelConfig - enable_resource_pool: false - n_gpus_per_node: 8 - nnodes: 0 + task: null model_path: null inference: _target_: verl.workers.config.RolloutConfig From 95097bb4447027818a972e943ad595ea0735a233 Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Fri, 27 Mar 2026 01:36:28 -0500 Subject: [PATCH 4/8] Add task and model path checks --- verl/workers/config/distillation.py | 35 +++++++++++++++++------------ 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/verl/workers/config/distillation.py b/verl/workers/config/distillation.py index d3bdb831946..047f939c127 100644 --- a/verl/workers/config/distillation.py +++ b/verl/workers/config/distillation.py @@ -130,10 +130,13 @@ class DistillationTeacherModelConfig(BaseConfig): model_path: Optional[str] = None inference: RolloutConfig = field(default_factory=RolloutConfig) - def is_configured(self) -> bool: - if self.task is not None and self.model_path is None: - raise ValueError("DistillationTeacherModelConfig is misconfigured with task but no model_path.") - return self.model_path is not None + def is_configured(self, is_multi: bool) -> bool: + configured = self.model_path is not None + if self.task is not None and not configured: + raise ValueError(f"{self.task=} is set but model_path is not set for this teacher model config.") + if is_multi and configured and self.task is None: + raise ValueError("task must be specified for multi-teacher setups.") + return configured def validate_and_prepare_for_distillation(self, use_topk: bool, topk: Optional[int]) -> None: # Prompt + Response from student are fed into teacher as context @@ -220,13 +223,13 @@ def __post_init__(self): return self.teacher_models = self._resolve_teacher_models() - if len(self.teacher_models) != 1: - raise NotImplementedError("`teacher_models` are not supported yet in the runtime path.") for teacher_model in self.teacher_models.values(): teacher_model.validate_and_prepare_for_distillation( use_topk=self.distillation_loss.loss_settings.use_topk, topk=self.distillation_loss.topk, ) + if len(self.teacher_models) != 1: + raise NotImplementedError("Multiple teacher models are not supported yet in the runtime path.") def get_single_teacher_model(self) -> DistillationTeacherModelConfig: if len(self.teacher_models) != 1: @@ -236,19 +239,23 @@ def get_single_teacher_model(self) -> DistillationTeacherModelConfig: return next(iter(self.teacher_models.values())) def _resolve_teacher_models(self) -> dict[str, DistillationTeacherModelConfig]: - if self.teacher_model.is_configured() and self.teacher_models: + if self.teacher_model.is_configured(is_multi=False) and self.teacher_models: raise ValueError("Specify either distillation.teacher_model or distillation.teacher_models, not both.") - if self.teacher_models: - teacher_models = {} - for model_name, teacher_model in self.teacher_models.items(): - teacher_model = omega_conf_to_dataclass(teacher_model, dataclass_type=DistillationTeacherModelConfig) - if teacher_model.task is None: - raise ValueError(f"distillation.teacher_models.{model_name}.task must be non-null.") + teacher_models = {} + for model_name, teacher_model in self.teacher_models.items(): + teacher_model = omega_conf_to_dataclass(teacher_model, dataclass_type=DistillationTeacherModelConfig) + if teacher_model.is_configured(is_multi=True): teacher_models[model_name] = teacher_model + if teacher_models: + if self.teacher_model.is_configured(is_multi=False): + raise ValueError( + "Multiple teacher models are configured in distillation.teacher_models, " + "but distillation.teacher_model is also configured." + ) return teacher_models - if not self.teacher_model.is_configured(): + if not self.teacher_model.is_configured(is_multi=False): raise ValueError( "Distillation is enabled but no teacher model is configured. " "Please configure distillation.teacher_model or distillation.teacher_models." From 309e1406b3149680bd0f342e618fda79ff62202e Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Fri, 27 Mar 2026 12:12:07 -0500 Subject: [PATCH 5/8] Compose cfg test --- scripts/mopd_cfg_tests.py | 57 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 scripts/mopd_cfg_tests.py diff --git a/scripts/mopd_cfg_tests.py b/scripts/mopd_cfg_tests.py new file mode 100644 index 00000000000..2b993ca3074 --- /dev/null +++ b/scripts/mopd_cfg_tests.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +"""Compose the 5774 multi-teacher MOPD config shape. + +Usage: + python scripts/mopd_cfg_tests.py +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from pprint import pprint + +from hydra import compose, initialize_config_dir +from omegaconf import OmegaConf + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from verl.utils.config import omega_conf_to_dataclass + + +def compose_from_overrides(config_dir: Path, override_string: str): + overrides = override_string.split() + with initialize_config_dir(version_base=None, config_dir=str(config_dir)): + return compose(config_name="ppo_trainer.yaml", overrides=overrides) + + +def main() -> int: + repo = REPO_ROOT + config_dir = repo / "verl" / "trainer" / "config" + + if not config_dir.is_dir(): + print(f"ERROR: Config directory not found: {config_dir}", file=sys.stderr) + return 2 + + multi_teacher_overrides = ( + "distillation.enabled=False " + "+distillation.teacher_models.geo3k.task=geo3k " + "+distillation.teacher_models.geo3k.model_path=path/to/geo3k_teacher " + "+distillation.teacher_models.geo3k.inference.tensor_model_parallel_size=1 " + "+distillation.teacher_models.geo3k.inference.gpu_memory_utilization=0.3 " + "+distillation.teacher_models.gsm8k.task=gsm8k " + "+distillation.teacher_models.gsm8k.model_path=path/to/gsm8k_teacher " + "+distillation.teacher_models.gsm8k.inference.tensor_model_parallel_size=1 " + "+distillation.teacher_models.gsm8k.inference.gpu_memory_utilization=0.3 " + ) + + multi_teacher_cfg = compose_from_overrides(config_dir, multi_teacher_overrides) + pprint(omega_conf_to_dataclass(multi_teacher_cfg.distillation)) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From dd0f0b94055ad4ee4f3b79685360ee94ea90a5ea Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Mon, 6 Apr 2026 21:36:43 -0500 Subject: [PATCH 6/8] PC --- scripts/mopd_cfg_tests.py | 57 ----------------------------- verl/workers/config/distillation.py | 1 - 2 files changed, 58 deletions(-) delete mode 100644 scripts/mopd_cfg_tests.py diff --git a/scripts/mopd_cfg_tests.py b/scripts/mopd_cfg_tests.py deleted file mode 100644 index 2b993ca3074..00000000000 --- a/scripts/mopd_cfg_tests.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -"""Compose the 5774 multi-teacher MOPD config shape. - -Usage: - python scripts/mopd_cfg_tests.py -""" - -from __future__ import annotations - -import sys -from pathlib import Path -from pprint import pprint - -from hydra import compose, initialize_config_dir -from omegaconf import OmegaConf - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -from verl.utils.config import omega_conf_to_dataclass - - -def compose_from_overrides(config_dir: Path, override_string: str): - overrides = override_string.split() - with initialize_config_dir(version_base=None, config_dir=str(config_dir)): - return compose(config_name="ppo_trainer.yaml", overrides=overrides) - - -def main() -> int: - repo = REPO_ROOT - config_dir = repo / "verl" / "trainer" / "config" - - if not config_dir.is_dir(): - print(f"ERROR: Config directory not found: {config_dir}", file=sys.stderr) - return 2 - - multi_teacher_overrides = ( - "distillation.enabled=False " - "+distillation.teacher_models.geo3k.task=geo3k " - "+distillation.teacher_models.geo3k.model_path=path/to/geo3k_teacher " - "+distillation.teacher_models.geo3k.inference.tensor_model_parallel_size=1 " - "+distillation.teacher_models.geo3k.inference.gpu_memory_utilization=0.3 " - "+distillation.teacher_models.gsm8k.task=gsm8k " - "+distillation.teacher_models.gsm8k.model_path=path/to/gsm8k_teacher " - "+distillation.teacher_models.gsm8k.inference.tensor_model_parallel_size=1 " - "+distillation.teacher_models.gsm8k.inference.gpu_memory_utilization=0.3 " - ) - - multi_teacher_cfg = compose_from_overrides(config_dir, multi_teacher_overrides) - pprint(omega_conf_to_dataclass(multi_teacher_cfg.distillation)) - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/verl/workers/config/distillation.py b/verl/workers/config/distillation.py index 5307d387057..3dba1609a1d 100644 --- a/verl/workers/config/distillation.py +++ b/verl/workers/config/distillation.py @@ -139,7 +139,6 @@ def is_configured(self, is_multi: bool) -> bool: return configured def validate_and_prepare_for_distillation(self, use_topk: bool, topk: Optional[int]) -> None: - # Prompt + Response from student are fed into teacher as context max_model_len = self.inference.max_model_len student_prompt_length = self.inference.prompt_length From 9ebd62db669ea559fde607b196aef81d7edaca4f Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Mon, 6 Apr 2026 21:42:52 -0500 Subject: [PATCH 7/8] top-k check --- verl/workers/config/distillation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/verl/workers/config/distillation.py b/verl/workers/config/distillation.py index 3dba1609a1d..17d675b5a50 100644 --- a/verl/workers/config/distillation.py +++ b/verl/workers/config/distillation.py @@ -155,9 +155,12 @@ def validate_and_prepare_for_distillation(self, use_topk: bool, topk: Optional[i self._validate_topk_logprobs(use_topk=use_topk, topk=topk) def _validate_topk_logprobs(self, use_topk: bool, topk: Optional[int]) -> None: - if not use_topk or topk is None: + if not use_topk: return + if topk is None: + raise ValueError("topk must be specified when use_topk is True.") + engine_name = self.inference.name engine_kwargs = self.inference.engine_kwargs match engine_name: From 432188b6fb36aa6f2d600644c12d1111cf90d84f Mon Sep 17 00:00:00 2001 From: JacobHelwig Date: Mon, 6 Apr 2026 21:47:25 -0500 Subject: [PATCH 8/8] Correct check --- verl/workers/config/distillation.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/verl/workers/config/distillation.py b/verl/workers/config/distillation.py index 17d675b5a50..0e1db1fea16 100644 --- a/verl/workers/config/distillation.py +++ b/verl/workers/config/distillation.py @@ -157,7 +157,6 @@ def validate_and_prepare_for_distillation(self, use_topk: bool, topk: Optional[i def _validate_topk_logprobs(self, use_topk: bool, topk: Optional[int]) -> None: if not use_topk: return - if topk is None: raise ValueError("topk must be specified when use_topk is True.") @@ -233,21 +232,15 @@ def get_single_teacher_model(self) -> DistillationTeacherModelConfig: return next(iter(self.teacher_models.values())) def _resolve_teacher_models(self) -> dict[str, DistillationTeacherModelConfig]: - if self.teacher_model.is_configured(is_multi=False) and self.teacher_models: - raise ValueError("Specify either distillation.teacher_model or distillation.teacher_models, not both.") - - teacher_models = {} - for model_name, teacher_model in self.teacher_models.items(): - teacher_model = omega_conf_to_dataclass(teacher_model, dataclass_type=DistillationTeacherModelConfig) - if teacher_model.is_configured(is_multi=True): - teacher_models[model_name] = teacher_model - if teacher_models: + if self.teacher_models: if self.teacher_model.is_configured(is_multi=False): - raise ValueError( - "Multiple teacher models are configured in distillation.teacher_models, " - "but distillation.teacher_model is also configured." - ) - return teacher_models + raise ValueError("Specify either distillation.teacher_model or distillation.teacher_models, not both.") + teacher_models = {} + for model_name, teacher_model in self.teacher_models.items(): + teacher_model = omega_conf_to_dataclass(teacher_model, dataclass_type=DistillationTeacherModelConfig) + if teacher_model.is_configured(is_multi=True): + teacher_models[model_name] = teacher_model + return teacher_models if not self.teacher_model.is_configured(is_multi=False): raise ValueError(