Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
7 changes: 3 additions & 4 deletions examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
6 changes: 2 additions & 4 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def __init__(
if self.distillation_enabled:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.distillation_config)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.stream_teacher_with_rollout = self.distillation_config.teacher_model.enable_resource_pool
self.stream_teacher_with_rollout = self.distillation_config.enable_resource_pool

if self.stream_teacher_with_rollout:
if teacher_servers is None:
Expand Down Expand Up @@ -1019,9 +1019,7 @@ def __init__(

self.teacher_model_manager = teacher_model_manager
self.distillation_enabled = is_distillation_enabled(self.config.get("distillation", None))
self.stream_teacher_with_rollout = (
self.distillation_enabled and self.config.distillation.teacher_model.enable_resource_pool
)
self.stream_teacher_with_rollout = self.distillation_enabled and self.config.distillation.enable_resource_pool

assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode"

Expand Down
11 changes: 6 additions & 5 deletions verl/experimental/teacher_loop/teacher_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
from verl.protocol import DataProto
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import DistillationConfig, DistillationLossConfig
from verl.workers.config import DistillationConfig, DistillationLossConfig, DistillationTeacherModelConfig


def _get_teacher_sampling_params(
distillation_config: DistillationConfig,
teacher_model_config: DistillationTeacherModelConfig,
distillation_loss_config: DistillationLossConfig,
) -> dict[str, Any]:
"""Get sampling parameters for teacher model when computing log probabilities for distillation."""
if distillation_config.teacher_model.inference.temperature != 1.0:
if teacher_model_config.inference.temperature != 1.0:
raise NotImplementedError("vLLM does not support temperature for prompt_logprobs.")

num_logprobs = distillation_loss_config.topk if distillation_loss_config.loss_settings.use_topk else 0
return {
"max_tokens": 1,
"temperature": distillation_config.teacher_model.inference.temperature,
"temperature": teacher_model_config.inference.temperature,
"prompt_logprobs": num_logprobs,
}

Expand Down Expand Up @@ -102,6 +102,7 @@ def __init__(
else:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(distillation_config)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.teacher_model_config = self.distillation_config.get_single_teacher_model()
self.pad_token_id = pad_token_id

async def compute_teacher_logprobs_single(
Expand All @@ -114,7 +115,7 @@ async def compute_teacher_logprobs_single(
teacher_output = await self.generate(
request_id=uuid4().hex,
prompt_ids=sequence_ids,
sampling_params=_get_teacher_sampling_params(self.distillation_config, self.distillation_loss_config),
sampling_params=_get_teacher_sampling_params(self.teacher_model_config, self.distillation_loss_config),
image_data=multi_modal_data.get("images"),
video_data=multi_modal_data.get("videos"),
)
Expand Down
10 changes: 7 additions & 3 deletions verl/experimental/teacher_loop/teacher_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self.sleep()

def _initialize_llm_servers(self):
teacher_model_config: DistillationTeacherModelConfig = self.config.teacher_model
teacher_model_config: DistillationTeacherModelConfig = self.config.get_single_teacher_model()
teacher_world_size = (
teacher_model_config.inference.tensor_model_parallel_size
* teacher_model_config.inference.data_parallel_size
Expand All @@ -63,8 +63,12 @@ def _initialize_llm_servers(self):
world_size = (
self.resource_pool.world_size
if self.resource_pool # colocate mode
else teacher_model_config.n_gpus_per_node * teacher_model_config.nnodes # standalone mode
else self.config.n_gpus_per_node * self.config.nnodes # standalone mode
)
if world_size % teacher_world_size != 0:
raise ValueError(
f"Teacher world size {teacher_world_size} must divide allocated resource pool size {world_size}."
)
num_replicas = world_size // teacher_world_size

rollout_replica_class = get_rollout_replica_class(teacher_model_config.inference.name)
Expand All @@ -80,7 +84,7 @@ def _initialize_llm_servers(self):
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
gpus_per_node=teacher_model_config.n_gpus_per_node,
gpus_per_node=self.config.n_gpus_per_node,
is_teacher_model=True,
)
for replica_rank in range(num_replicas)
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -783,11 +782,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -719,11 +718,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -741,11 +740,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,6 @@ algorithm:
distillation:
_target_: verl.workers.config.DistillationConfig
enabled: false
num_workers: 8
distillation_loss:
_target_: verl.workers.config.DistillationLossConfig
loss_mode: k3
Expand All @@ -696,11 +695,13 @@ distillation:
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
teacher_models: {}
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
enable_resource_pool: false
n_gpus_per_node: 8
nnodes: 0
task: null
model_path: null
inference:
_target_: verl.workers.config.RolloutConfig
Expand Down
26 changes: 12 additions & 14 deletions verl/trainer/config/distillation/distillation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ defaults:
# Whether to enable distillation.
enabled: false

# we launch num_workers teacher managers to parallelize the teacher logprob computation
num_workers: 8

# distillation loss config
distillation_loss:

Expand Down Expand Up @@ -58,21 +55,23 @@ distillation_loss:
clip_ratio_high: 0.2


# teacher model config
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
# Whether to enable a separate resource pool for distillation teacher model(s)
enable_resource_pool: false

# Whether to enable a separate resource pool for teacher model(s)
enable_resource_pool: false
# Number of GPUs per node in the distillation teacher resource pool
n_gpus_per_node: 8

# Number of GPUs per node to use for distillation teacher model(s)
n_gpus_per_node: 8
# Number of nodes in the distillation teacher resource pool
nnodes: 0

# Number of nodes to use for distillation teacher model(s)
nnodes: 0
# multi-teacher configs
teacher_models: {}

# single-teacher config
teacher_model:
_target_: verl.workers.config.DistillationTeacherModelConfig
task: null
model_path: null

inference:
_target_: verl.workers.config.RolloutConfig
name: ${oc.select:actor_rollout_ref.rollout.name}
Expand All @@ -94,7 +93,6 @@ teacher_model:
enable_prefix_caching: true
disable_log_stats: true
skip_tokenizer_init: false

prompt_length: ${oc.select:actor_rollout_ref.rollout.prompt_length}
response_length: ${oc.select:actor_rollout_ref.rollout.response_length}
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}
22 changes: 10 additions & 12 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,17 @@ def init_resource_pool_mgr(self, config):

distillation_config = config.get("distillation")
if is_distillation_enabled(distillation_config):
if distillation_config.teacher_model.enable_resource_pool:
if distillation_config.teacher_model.n_gpus_per_node <= 0:
raise ValueError("config.distillation.teacher_model.n_gpus_per_node must be greater than 0")
if distillation_config.teacher_model.nnodes <= 0:
raise ValueError("config.distillation.teacher_model.nnodes must be greater than 0")

teacher_pool = [
distillation_config.teacher_model.n_gpus_per_node
] * distillation_config.teacher_model.nnodes
if distillation_config.enable_resource_pool:
if distillation_config.n_gpus_per_node <= 0:
raise ValueError("config.distillation.n_gpus_per_node must be greater than 0")
if distillation_config.nnodes <= 0:
raise ValueError("config.distillation.nnodes must be greater than 0")

teacher_pool = [distillation_config.n_gpus_per_node] * distillation_config.nnodes
resource_pool_spec["teacher_pool"] = teacher_pool
else:
distillation_config.teacher_model.nnodes = config.trainer.nnodes
distillation_config.teacher_model.n_gpus_per_node = config.trainer.n_gpus_per_node
distillation_config.nnodes = config.trainer.nnodes
distillation_config.n_gpus_per_node = config.trainer.n_gpus_per_node

from verl.trainer.ppo.ray_trainer import ResourcePoolManager

Expand All @@ -281,7 +279,7 @@ def add_teacher_model_resource_pool(self, config):
if is_distillation_enabled(config.get("distillation")):
# we do not use teacher model workers, so we only register teacher model in resource pool
# without registering a teacher model worker in role-worker mapping
if config.distillation.teacher_model.enable_resource_pool:
if config.distillation.enable_resource_pool:
self.mapping[Role.TeacherModel] = "teacher_pool"
else:
self.mapping[Role.TeacherModel] = "global_pool"
Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict
return batch_reward

def _should_compute_teacher_colocate(self, batch: DataProto) -> bool:
return self.use_teacher_policy and not self.distillation_config.teacher_model.enable_resource_pool
return self.use_teacher_policy and not self.distillation_config.enable_resource_pool

def _compute_teacher_colocate(self, batch: DataProto) -> DataProto:
"""Compute teacher logprobs after rollout when teacher and student are colocated."""
Expand Down
Loading
Loading