Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
7 changes: 3 additions & 4 deletions examples/on_policy_distillation_trainer/run_qwen_gsm8k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ MODEL=(

DISTILLATION=(
distillation.enabled=True
distillation.num_workers=8
distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.teacher_model.nnodes=1
distillation.enable_resource_pool=$TEACHER_RESOURCE_POOL
distillation.n_gpus_per_node=$TEACHER_WORLD_SIZE
distillation.nnodes=1
distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
distillation.teacher_model.inference.tensor_model_parallel_size=1
distillation.teacher_model.inference.name=$ROLLOUT_NAME
Expand Down
17 changes: 7 additions & 10 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.experimental.teacher_loop import TeacherModelManager
from verl.experimental.teacher_loop import MultiTeacherModelManager
from verl.protocol import DataProto
from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
from verl.trainer.distillation import is_distillation_enabled
Expand Down Expand Up @@ -465,7 +465,7 @@ def __init__(
if self.distillation_enabled:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.distillation_config)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.stream_teacher_with_rollout = self.distillation_config.teacher_model.enable_resource_pool
self.stream_teacher_with_rollout = self.distillation_config.enable_resource_pool

if self.stream_teacher_with_rollout:
if teacher_servers is None:
Expand All @@ -479,8 +479,6 @@ def __init__(
config,
teacher_servers,
load_balancer_handle=teacher_load_balancer_handle,
distillation_config=self.distillation_config,
pad_token_id=self.model_config.tokenizer.pad_token_id,
)
else:
self.teacher_server_manager = None
Expand Down Expand Up @@ -1024,7 +1022,8 @@ class AgentLoopManager:
config (DictConfig): whole config for main entrypoint.
worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
rollout_resource_pool (RayResourcePool): Resource pool for hybrid mode, only used by TensorRT-LLM.
teacher_model_manager (TeacherModelManager): Manager for streaming teacher computation, used for distillation.
teacher_model_manager (MultiTeacherModelManager): Manager for streaming teacher computation, used for
distillation.
reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
"""

Expand All @@ -1033,7 +1032,7 @@ def __init__(
config: DictConfig,
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
teacher_model_manager: TeacherModelManager = None,
teacher_model_manager: MultiTeacherModelManager = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.config = config
Expand All @@ -1044,9 +1043,7 @@ def __init__(

self.teacher_model_manager = teacher_model_manager
self.distillation_enabled = is_distillation_enabled(self.config.get("distillation", None))
self.stream_teacher_with_rollout = (
self.distillation_enabled and self.config.distillation.teacher_model.enable_resource_pool
)
self.stream_teacher_with_rollout = self.distillation_enabled and self.config.distillation.enable_resource_pool

assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode"

Expand All @@ -1069,7 +1066,7 @@ async def create(
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
teacher_model_manager: TeacherModelManager = None,
teacher_model_manager: MultiTeacherModelManager = None,
):
"""Create agent loop manager."""
instance = cls(config, worker_group, rollout_resource_pool, teacher_model_manager, reward_loop_worker_handles)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
AsyncLLMServerManager,
TokenOutput,
)
from verl.experimental.teacher_loop import TeacherModelManager
from verl.experimental.teacher_loop import MultiTeacherModelManager
from verl.protocol import DataProto
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import auto_await
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(
config: DictConfig,
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
teacher_model_manager: TeacherModelManager = None,
teacher_model_manager: MultiTeacherModelManager = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
Expand Down
4 changes: 2 additions & 2 deletions verl/experimental/teacher_loop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .teacher_model import TeacherModelManager
from .teacher_model import MultiTeacherModelManager

__all__ = ["TeacherModelManager"]
__all__ = ["MultiTeacherModelManager"]
56 changes: 40 additions & 16 deletions verl/experimental/teacher_loop/teacher_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,26 @@
from verl.protocol import DataProto
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import DistillationConfig, DistillationLossConfig
from verl.workers.config import (
DistillationConfig,
DistillationLossConfig,
DistillationTeacherModelConfig,
HFModelConfig,
)


def _get_teacher_sampling_params(
distillation_config: DistillationConfig,
teacher_model_config: DistillationTeacherModelConfig,
distillation_loss_config: DistillationLossConfig,
) -> dict[str, Any]:
"""Get sampling parameters for teacher model when computing log probabilities for distillation."""
if distillation_config.teacher_model.inference.temperature != 1.0:
if teacher_model_config.inference.temperature != 1.0:
raise NotImplementedError("vLLM does not support temperature for prompt_logprobs.")

num_logprobs = distillation_loss_config.topk if distillation_loss_config.loss_settings.use_topk else 0
return {
"max_tokens": 1,
"temperature": distillation_config.teacher_model.inference.temperature,
"temperature": teacher_model_config.inference.temperature,
"prompt_logprobs": num_logprobs,
}

Expand Down Expand Up @@ -85,24 +90,42 @@ def _unpad_teacher_inputs(data: DataProto) -> tuple[list[int], int, int]:
return sequence_ids, valid_prompt_length, valid_response_length


class AsyncTeacherLLMServerManager(AsyncLLMServerManager):
"""Teacher-specific async client used for distillation logprob computation."""
class AsyncTeacherLLMServerManager:
"""Teacher-specific async client used for distillation logprob computation.
TODO: MOPD -- servers and load_balancer_handle become dicts (one entry per teacher)
"""

def __init__(
self,
config: DictConfig,
servers: list[tuple[str, ray.actor.ActorHandle]],
load_balancer_handle: ray.actor.ActorHandle,
distillation_config: DictConfig | DistillationConfig,
pad_token_id: int,
):
super().__init__(config=config, servers=servers, load_balancer_handle=load_balancer_handle)
if isinstance(distillation_config, DistillationConfig):
self.distillation_config = distillation_config
else:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(distillation_config)
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(config.distillation)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.pad_token_id = pad_token_id
teacher_model_config: DistillationTeacherModelConfig = self.distillation_config.get_single_teacher_model()
self.teacher_model_config = teacher_model_config

# Get pad token ID
model_config = HFModelConfig(path=teacher_model_config.model_path)
text_tokenizer = model_config.tokenizer
if text_tokenizer is None:
raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}")
self.pad_token_id = text_tokenizer.pad_token_id
Comment on lines +110 to +114
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Constructing HFModelConfig and accessing model_config.tokenizer inside __init__ may trigger a redundant, expensive tokenizer load every time an AsyncTeacherLLMServerManager is instantiated. Because this manager is created in each AgentLoopWorker, and there can be many such workers, this could add significant startup overhead and memory pressure. Consider passing pad_token_id in directly, or ensuring the tokenizer is loaded once and cached.

self._initialize_teacher_server_managers(
config=config, servers=servers, load_balancer_handle=load_balancer_handle
)

def _initialize_teacher_server_managers(
self,
config: DictConfig,
servers: list[tuple[str, ray.actor.ActorHandle]],
load_balancer_handle: ray.actor.ActorHandle,
):
# TODO: MOPD -- balancers/servers become a dict (one per teacher)
self.server_manager = AsyncLLMServerManager(
config=config, servers=servers, load_balancer_handle=load_balancer_handle
)

async def compute_teacher_logprobs_single(
self,
Expand All @@ -111,10 +134,11 @@ async def compute_teacher_logprobs_single(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute teacher log probabilities for a single unpadded sequence."""
multi_modal_data = multi_modal_data or {}
teacher_output = await self.generate(
# TODO: MOPD -- select server manager from server manager dict based on example task key
teacher_output = await self.server_manager.generate(
request_id=uuid4().hex,
prompt_ids=sequence_ids,
sampling_params=_get_teacher_sampling_params(self.distillation_config, self.distillation_loss_config),
sampling_params=_get_teacher_sampling_params(self.teacher_model_config, self.distillation_loss_config),
image_data=multi_modal_data.get("images"),
video_data=multi_modal_data.get("videos"),
)
Expand Down
Loading
Loading