Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
7 changes: 3 additions & 4 deletions examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
6 changes: 2 additions & 4 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def __init__(
if self.distillation_enabled:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.distillation_config)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.stream_teacher_with_rollout = self.distillation_config.teacher_model.enable_resource_pool
self.stream_teacher_with_rollout = self.distillation_config.enable_resource_pool

if self.stream_teacher_with_rollout:
if teacher_servers is None:
Expand Down Expand Up @@ -1012,9 +1012,7 @@ def __init__(

self.teacher_model_manager = teacher_model_manager
self.distillation_enabled = is_distillation_enabled(self.config.get("distillation", None))
self.stream_teacher_with_rollout = (
self.distillation_enabled and self.config.distillation.teacher_model.enable_resource_pool
)
self.stream_teacher_with_rollout = self.distillation_enabled and self.config.distillation.enable_resource_pool

assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode"

Expand Down
11 changes: 6 additions & 5 deletions verl/experimental/teacher_loop/teacher_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
from verl.protocol import DataProto
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import DistillationConfig, DistillationLossConfig
from verl.workers.config import DistillationConfig, DistillationLossConfig, DistillationTeacherModelConfig


def _get_teacher_sampling_params(
distillation_config: DistillationConfig,
teacher_model_config: DistillationTeacherModelConfig,
distillation_loss_config: DistillationLossConfig,
) -> dict[str, Any]:
"""Get sampling parameters for teacher model when computing log probabilities for distillation."""
if distillation_config.teacher_model.inference.temperature != 1.0:
if teacher_model_config.inference.temperature != 1.0:
raise NotImplementedError("vLLM does not support temperature for prompt_logprobs.")

num_logprobs = distillation_loss_config.topk if distillation_loss_config.loss_settings.use_topk else 0
return {
"max_tokens": 1,
"temperature": distillation_config.teacher_model.inference.temperature,
"temperature": teacher_model_config.inference.temperature,
"prompt_logprobs": num_logprobs,
}

Expand Down Expand Up @@ -102,6 +102,7 @@ def __init__(
else:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(distillation_config)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.teacher_model_config = self.distillation_config.get_single_teacher_model()
self.pad_token_id = pad_token_id

async def compute_teacher_logprobs_single(
Expand All @@ -114,7 +115,7 @@ async def compute_teacher_logprobs_single(
teacher_output = await self.generate(
request_id=uuid4().hex,
prompt_ids=sequence_ids,
sampling_params=_get_teacher_sampling_params(self.distillation_config, self.distillation_loss_config),
sampling_params=_get_teacher_sampling_params(self.teacher_model_config, self.distillation_loss_config),
image_data=multi_modal_data.get("images"),
video_data=multi_modal_data.get("videos"),
)
Expand Down
10 changes: 7 additions & 3 deletions verl/experimental/teacher_loop/teacher_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self.sleep()

def _initialize_llm_servers(self):
teacher_model_config: DistillationTeacherModelConfig = self.config.teacher_model
teacher_model_config: DistillationTeacherModelConfig = self.config.get_single_teacher_model()
teacher_world_size = (
teacher_model_config.inference.tensor_model_parallel_size
* teacher_model_config.inference.data_parallel_size
Expand All @@ -63,8 +63,12 @@ def _initialize_llm_servers(self):
world_size = (
self.resource_pool.world_size
if self.resource_pool # colocate mode
else teacher_model_config.n_gpus_per_node * teacher_model_config.nnodes # standalone mode
else self.config.n_gpus_per_node * self.config.nnodes # standalone mode
)
if world_size % teacher_world_size != 0:
raise ValueError(
f"Teacher world size {teacher_world_size} must divide allocated resource pool size {world_size}."
)
num_replicas = world_size // teacher_world_size

rollout_replica_class = get_rollout_replica_class(teacher_model_config.inference.name)
Expand All @@ -80,7 +84,7 @@ def _initialize_llm_servers(self):
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
gpus_per_node=teacher_model_config.n_gpus_per_node,
gpus_per_node=self.config.n_gpus_per_node,
is_teacher_model=True,
)
for replica_rank in range(num_replicas)
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -749,11 +748,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -642,11 +641,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -674,11 +673,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -619,11 +618,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
26 changes: 12 additions & 14 deletions verl/trainer/config/distillation/distillation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ defaults:
# Whether to enable distillation.
enabled: false

# we launch num_workers teacher managers to parallelize the teacher logprob computation
num_workers: 8

# distillation loss config
distillation_loss:

Expand Down Expand Up @@ -58,21 +55,23 @@ distillation_loss:
clip_ratio_high: 0.2


# teacher model config
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
# Whether to enable a separate resource pool for distillation teacher model(s)
enable_resource_pool: false

# Whether to enable separate resource pool for teacher model(s)
enable_resource_pool: false
# Number of GPUs per node in the distillation teacher resource pool
n_gpus_per_node: 8

# Number of GPUs per node to use for distillation teacher model(s)
n_gpus_per_node: 8
# Number of nodes in the distillation teacher resource pool
nnodes: 0

# Number of nodes to use for distillation teacher model(s)
nnodes: 0
# multi-teacher configs
teacher_models: {}

# single-teacher config
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
task: null
model_path: null

inference:
_target_: verl.workers.config.RolloutConfig
name: ${oc.select:actor_rollout_ref.rollout.name}
Expand All @@ -94,7 +93,6 @@ teacher_model:
enable_prefix_caching: true
disable_log_stats: true
skip_tokenizer_init: false

prompt_length: ${oc.select:actor_rollout_ref.rollout.prompt_length}
response_length: ${oc.select:actor_rollout_ref.rollout.response_length}
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}
22 changes: 10 additions & 12 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,17 @@ def init_resource_pool_mgr(self, config):

distillation_config = config.get("distillation")
if is_distillation_enabled(distillation_config):
if distillation_config.teacher_model.enable_resource_pool:
if distillation_config.teacher_model.n_gpus_per_node <= 0:
raise ValueError("config.distillation.teacher_model.n_gpus_per_node must be greater than 0")
if distillation_config.teacher_model.nnodes <= 0:
raise ValueError("config.distillation.teacher_model.nnodes must be greater than 0")

teacher_pool = [
distillation_config.teacher_model.n_gpus_per_node
] * distillation_config.teacher_model.nnodes
if distillation_config.enable_resource_pool:
if distillation_config.n_gpus_per_node <= 0:
raise ValueError("config.distillation.n_gpus_per_node must be greater than 0")
if distillation_config.nnodes <= 0:
raise ValueError("config.distillation.nnodes must be greater than 0")

teacher_pool = [distillation_config.n_gpus_per_node] * distillation_config.nnodes
resource_pool_spec["teacher_pool"] = teacher_pool
else:
distillation_config.teacher_model.nnodes = config.trainer.nnodes
distillation_config.teacher_model.n_gpus_per_node = config.trainer.n_gpus_per_node
distillation_config.nnodes = config.trainer.nnodes
distillation_config.n_gpus_per_node = config.trainer.n_gpus_per_node

from verl.trainer.ppo.ray_trainer import ResourcePoolManager

Expand All @@ -281,7 +279,7 @@ def add_teacher_model_resource_pool(self, config):
if is_distillation_enabled(config.get("distillation")):
# we do not use teacher model workers, so we only register teacher model in resource pool
# without registering a teacher model worker in role-worker mapping
if config.distillation.teacher_model.enable_resource_pool:
if config.distillation.enable_resource_pool:
self.mapping[Role.TeacherModel] = "teacher_pool"
else:
self.mapping[Role.TeacherModel] = "global_pool"
Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict
return batch_reward

def _should_compute_teacher_colocate(self, batch: DataProto) -> bool:
return self.use_teacher_policy and not self.distillation_config.teacher_model.enable_resource_pool
return self.use_teacher_policy and not self.distillation_config.enable_resource_pool

def _compute_teacher_colocate(self, batch: DataProto) -> DataProto:
"""Compute teacher logprobs after rollout when teacher and student are colocated."""
Expand Down
Loading
Loading