Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.experimental.teacher_loop import TeacherModelManager
from verl.experimental.teacher_loop import MultiTeacherModelManager
from verl.protocol import DataProto
from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
from verl.trainer.distillation import is_distillation_enabled
Expand Down Expand Up @@ -454,8 +454,6 @@ def __init__(
config,
teacher_servers,
load_balancer_handle=teacher_load_balancer_handle,
distillation_config=self.distillation_config,
pad_token_id=self.model_config.tokenizer.pad_token_id,
)
else:
self.teacher_server_manager = None
Expand Down Expand Up @@ -999,7 +997,8 @@ class AgentLoopManager:
config (DictConfig): whole config for main entrypoint.
worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
rollout_resource_pool (RayResourcePool): Resource pool for hybrid mode, only used by TensorRT-LLM.
teacher_model_manager (TeacherModelManager): Manager for streaming teacher computation, used for distillation.
teacher_model_manager (MultiTeacherModelManager): Manager for streaming teacher computation, used for
distillation.
reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
"""

Expand All @@ -1008,7 +1007,7 @@ def __init__(
config: DictConfig,
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
teacher_model_manager: TeacherModelManager = None,
teacher_model_manager: MultiTeacherModelManager = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.config = config
Expand Down Expand Up @@ -1044,7 +1043,7 @@ async def create(
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
teacher_model_manager: TeacherModelManager = None,
teacher_model_manager: MultiTeacherModelManager = None,
):
"""Create agent loop manager."""
instance = cls(config, worker_group, rollout_resource_pool, teacher_model_manager, reward_loop_worker_handles)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
AsyncLLMServerManager,
TokenOutput,
)
from verl.experimental.teacher_loop import TeacherModelManager
from verl.experimental.teacher_loop import MultiTeacherModelManager
from verl.protocol import DataProto
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import auto_await
Expand Down Expand Up @@ -152,7 +152,7 @@ def __init__(
config: DictConfig,
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
teacher_model_manager: TeacherModelManager = None,
teacher_model_manager: MultiTeacherModelManager = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
Expand Down
4 changes: 2 additions & 2 deletions verl/experimental/teacher_loop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .teacher_model import TeacherModelManager
from .teacher_model import MultiTeacherModelManager

__all__ = ["TeacherModelManager"]
__all__ = ["MultiTeacherModelManager"]
56 changes: 40 additions & 16 deletions verl/experimental/teacher_loop/teacher_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,26 @@
from verl.protocol import DataProto
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import DistillationConfig, DistillationLossConfig
from verl.workers.config import (
DistillationConfig,
DistillationLossConfig,
DistillationTeacherModelConfig,
HFModelConfig,
)


def _get_teacher_sampling_params(
distillation_config: DistillationConfig,
teacher_config: DistillationTeacherModelConfig,
distillation_loss_config: DistillationLossConfig,
) -> dict[str, Any]:
"""Get sampling parameters for teacher model when computing log probabilities for distillation."""
if distillation_config.teacher_model.inference.temperature != 1.0:
if teacher_config.inference.temperature != 1.0:
raise NotImplementedError("vLLM does not support temperature for prompt_logprobs.")

num_logprobs = distillation_loss_config.topk if distillation_loss_config.loss_settings.use_topk else 0
return {
"max_tokens": 1,
"temperature": distillation_config.teacher_model.inference.temperature,
"temperature": teacher_config.inference.temperature,
"prompt_logprobs": num_logprobs,
}

Expand Down Expand Up @@ -85,24 +90,42 @@ def _unpad_teacher_inputs(data: DataProto) -> tuple[list[int], int, int]:
return sequence_ids, valid_prompt_length, valid_response_length


class AsyncTeacherLLMServerManager(AsyncLLMServerManager):
"""Teacher-specific async client used for distillation logprob computation."""
class AsyncTeacherLLMServerManager:
"""Teacher-specific async client used for distillation logprob computation.
TODO: MOPD -- servers and load_balance_handle become a dict (one per teacher)
"""

def __init__(
self,
config: DictConfig,
servers: list[tuple[str, ray.actor.ActorHandle]],
load_balancer_handle: ray.actor.ActorHandle,
distillation_config: DictConfig | DistillationConfig,
pad_token_id: int,
):
super().__init__(config=config, servers=servers, load_balancer_handle=load_balancer_handle)
if isinstance(distillation_config, DistillationConfig):
self.distillation_config = distillation_config
else:
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(distillation_config)
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(config.distillation)
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
self.pad_token_id = pad_token_id
self.teacher_config: DistillationTeacherModelConfig = self.distillation_config.teacher_model

# Get pad token ID
teacher_model_config: DistillationTeacherModelConfig = self.distillation_config.teacher_model
model_config = HFModelConfig(path=teacher_model_config.model_path)
text_tokenizer = model_config.tokenizer
if model_config.tokenizer is None:
raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}")
self.pad_token_id = text_tokenizer.pad_token_id
Comment on lines +110 to +114
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Constructing HFModelConfig and accessing model_config.tokenizer inside __init__ may trigger a redundant and expensive tokenizer load every time an AsyncTeacherLLMServerManager is instantiated. Because this manager is created inside each AgentLoopWorker, and there can be many workers, this could add significant startup overhead and memory pressure. Consider passing pad_token_id in directly, as the previous constructor signature did, or making sure the tokenizer is loaded once and cached.

self._initialize_teacher_server_managers(
config=config, servers=servers, load_balancer_handle=load_balancer_handle
)

def _initialize_teacher_server_managers(
self,
config: DictConfig,
servers: list[tuple[str, ray.actor.ActorHandle]],
load_balancer_handle: ray.actor.ActorHandle,
):
# TODO: MOPD -- balancers/servers become a dict (one per teacher)
self.server_manager = AsyncLLMServerManager(
config=config, servers=servers, load_balancer_handle=load_balancer_handle
)

async def compute_teacher_logprobs_single(
self,
Expand All @@ -111,10 +134,11 @@ async def compute_teacher_logprobs_single(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute teacher log probabilities for a single unpadded sequence."""
multi_modal_data = multi_modal_data or {}
teacher_output = await self.generate(
# TODO: MOPD -- select server manager from server manager dict based on example task key
teacher_output = await self.server_manager.generate(
request_id=uuid4().hex,
prompt_ids=sequence_ids,
sampling_params=_get_teacher_sampling_params(self.distillation_config, self.distillation_loss_config),
sampling_params=_get_teacher_sampling_params(self.teacher_config, self.distillation_loss_config),
image_data=multi_modal_data.get("images"),
video_data=multi_modal_data.get("videos"),
)
Expand Down
131 changes: 94 additions & 37 deletions verl/experimental/teacher_loop/teacher_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,45 @@
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


@auto_await
async def _run_all(tasks: list):
    """Await every item in ``tasks`` concurrently and wait for all of them to finish.

    Args:
        tasks: Awaitables to run together. The callers in this module pass the
            direct results of calls such as ``replica.wake_up()`` rather than
            ``asyncio.Task`` objects, so the parameter is annotated as a plain
            ``list`` instead of ``list[asyncio.Task]``.

    Raises:
        Exception: ``asyncio.gather`` is used with its default
            ``return_exceptions=False``, so the first exception raised by any
            awaitable propagates to the caller.
    """
    # The gathered results are intentionally not returned; callers only need
    # to know that every awaitable has completed.
    await asyncio.gather(*tasks)
Comment on lines +32 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The type hint for tasks in _run_all is list[asyncio.Task], but the function is called with a list of coroutines (e.g., in TeacherModelManager.wake_up). While asyncio.gather accepts coroutines, the type hint is technically incorrect and might be misleading for static analysis tools.

Suggested change
async def _run_all(tasks: list[asyncio.Task]):
await asyncio.gather(*tasks)
async def _run_all(tasks: list):
await asyncio.gather(*tasks)



def _run_single(task):
    """Run a single awaitable to completion from synchronous code.

    Args:
        task: Awaitable (typically a coroutine object) to drive to completion.

    Returns:
        The value produced by awaiting ``task``.

    Raises:
        RuntimeError: If the calling thread already has a running event loop.
            ``asyncio.run`` cannot be nested, so the situation is detected up
            front and reported with an actionable message.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if loop_is_running:
        # asyncio.run() would raise here anyway, but it would leave ``task``
        # un-awaited and trigger a "coroutine was never awaited" warning.
        # Close it explicitly before raising the same exception type.
        if asyncio.iscoroutine(task):
            task.close()
        raise RuntimeError(
            "_run_single() cannot be called from a running event loop; "
            "await the coroutine directly instead."
        )

    async def _await_task():
        # asyncio.run() requires a coroutine; wrapping lets any awaitable be used.
        return await task

    return asyncio.run(_await_task())


class TeacherModelManager:
"""Teacher model manager."""

def __init__(
self,
config: DictConfig,
resource_pool: RayResourcePool = None,
teacher_model_config: DistillationTeacherModelConfig,
resource_pool: RayResourcePool,
):
"""
Initialize the teacher model manager.

Args:
config (DictConfig): Teacher model configuration.
resource_pool (RayResourcePool, optional): Resource pool. Defaults to None.
teacher_model_config (DistillationTeacherModelConfig): Teacher model configuration.
resource_pool (RayResourcePool): Resource pool.
"""

# Need dataclass conversion for max_logprobs handling in post_init
self.config: DistillationConfig = omega_conf_to_dataclass(config)
self.teacher_model_config = teacher_model_config
self.resource_pool = resource_pool
self._initialize_llm_servers()
self._initialize_async_server_manager()
self._initialize_load_balancer_handle()
self._initialize_router()

self.sleep()

def _initialize_llm_servers(self):
teacher_model_config: DistillationTeacherModelConfig = self.config.teacher_model
teacher_model_config = self.teacher_model_config
teacher_world_size = (
teacher_model_config.inference.tensor_model_parallel_size
* teacher_model_config.inference.data_parallel_size
Expand All @@ -71,10 +83,6 @@ def _initialize_llm_servers(self):
rollout_config = teacher_model_config.inference
model_config = HFModelConfig(path=teacher_model_config.model_path)
self.tokenizer = model_config.get_processor()
text_tokenizer = model_config.tokenizer
if model_config.tokenizer is None:
raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}")
self.pad_token_id = text_tokenizer.pad_token_id
self.rollout_replicas = [
rollout_replica_class(
replica_rank=replica_rank,
Expand All @@ -88,31 +96,23 @@ def _initialize_llm_servers(self):
if self.resource_pool:
split_resource_pools = split_resource_pool(self.resource_pool, split_size=teacher_world_size)
assert len(split_resource_pools) == len(self.rollout_replicas)
self._run_all(
_run_all(
[
server.init_colocated(resource_pool)
for server, resource_pool in zip(self.rollout_replicas, split_resource_pools, strict=True)
]
)
else:
self._run_all([server.init_standalone() for server in self.rollout_replicas])
_run_all([server.init_standalone() for server in self.rollout_replicas])
self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]

def _initialize_async_server_manager(self):
def _initialize_load_balancer_handle(self):
from verl.experimental.agent_loop.agent_loop import GlobalRequestLoadBalancer
from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager

self.load_balancer_handle = GlobalRequestLoadBalancer.remote(
server_actor_ids=self.server_addresses,
)
self.server_manager = AsyncTeacherLLMServerManager(
config=self.config,
servers=list(zip(self.server_addresses, self.server_handles, strict=True)),
load_balancer_handle=self.load_balancer_handle,
distillation_config=self.config,
pad_token_id=self.pad_token_id,
)

def _initialize_router(self):
worker_urls = [f"http://{server_address}" for server_address in self.server_addresses]
Expand All @@ -124,29 +124,86 @@ def _initialize_router(self):
def get_router_address(self):
return self.router_address

@auto_await
async def wake_up(self):
"""Wake up all rollout replica instances."""
await _run_all([replica.wake_up() for replica in self.rollout_replicas])

@auto_await
async def sleep(self):
"""Sleep all rollout replica instances."""
await _run_all([replica.sleep() for replica in self.rollout_replicas])


class MultiTeacherModelManager:
"""Multi Teacher model manager."""

def __init__(
self,
config: DictConfig,
resource_pool: RayResourcePool,
):
"""
Initialize the teacher model manager.

Args:
config (DictConfig): Full configuration.
resource_pool (RayResourcePool): Resource pool.
"""

# Need dataclass conversion for max_logprobs handling in post_init
self.config = config
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(config.distillation)

self.resource_pool = resource_pool
self._initialize_teacher_model_managers()
self._initialize_async_server_manager()
self.sleep()

def _initialize_teacher_model_managers(self):
"""TODO: MOPD -- split resource pool across teachers and init one TeacherModelManager per teacher."""
self.teacher_model_manager = TeacherModelManager(
teacher_model_config=self.distillation_config.teacher_model, resource_pool=self.resource_pool
)
self.server_addresses = self.teacher_model_manager.server_addresses
self.server_handles = self.teacher_model_manager.server_handles

def _initialize_async_server_manager(self):
"""TODO: MOPD -- balancers/servers become a dict (one per teacher)"""
from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager

self.load_balancer_handle = self.teacher_model_manager.load_balancer_handle
servers = list(zip(self.server_addresses, self.server_handles, strict=True))

# In standalone mode, server manager is initialized in the agent loop
if not self.distillation_config.teacher_model.enable_resource_pool:
self.server_manager = AsyncTeacherLLMServerManager(
config=self.config,
servers=servers,
load_balancer_handle=self.load_balancer_handle,
)

def get_router_address(self):
"""TODO: MOPD -- return dict of router addresses (one per teacher)"""
return self.teacher_model_manager.router_address

def compute_logprobs(self, data):
self.wake_up()
try:
return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data))
return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
finally:
self.sleep()
Comment on lines 190 to 195
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The compute_logprobs method is synchronous, but it calls self.wake_up() and self.sleep(), which are decorated with @auto_await. When called synchronously, @auto_await typically runs the coroutine to completion using the current event loop or by creating a new one. However, _run_single explicitly uses asyncio.run(), which creates a new event loop and will fail if an event loop is already running in the same thread (which might happen if auto_await left a loop active).

More importantly, mixing auto_await for some calls and asyncio.run (via _run_single) for others in the same synchronous method is fragile and inefficient. It is better to wrap the entire sequence of operations into a single async method and run that once.

Suggested change
def compute_logprobs(self, data):
self.wake_up()
try:
return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data))
return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
finally:
self.sleep()
def compute_logprobs(self, data):
async def _compute():
await self.wake_up()
try:
return await self.server_manager.compute_teacher_logprobs_batch(data)
finally:
await self.sleep()
return _run_single(_compute())


@auto_await
async def wake_up(self):
"""Wake up all rollout replica instances."""
await self._run_all([replica.wake_up() for replica in self.rollout_replicas])
"""Wake up all rollout replica instances.
TODO: MOPD -- wake up each of the teacher model managers.
"""
await _run_all([manager.wake_up() for manager in [self.teacher_model_manager]])

@auto_await
async def sleep(self):
"""Sleep all rollout replica instances."""
await self._run_all([replica.sleep() for replica in self.rollout_replicas])

@auto_await
async def _run_all(self, tasks: list[asyncio.Task]):
await asyncio.gather(*tasks)

def _run_single(self, task):
async def run():
return await task

return asyncio.run(run())
"""Sleep all rollout replica instances.
TODO: MOPD -- sleep each of the teacher model managers.
"""
await _run_all([manager.sleep() for manager in [self.teacher_model_manager]])
6 changes: 3 additions & 3 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,11 +845,11 @@ def init_workers(self):

# initialize teacher loop manager
if self.use_teacher_policy:
from verl.experimental.teacher_loop import TeacherModelManager
from verl.experimental.teacher_loop import MultiTeacherModelManager

teacher_resource_pool = self.resource_pool_manager.get_resource_pool(Role.TeacherModel)
self.teacher_model_manager = TeacherModelManager(
config=self.config.distillation,
self.teacher_model_manager = MultiTeacherModelManager(
config=self.config,
resource_pool=teacher_resource_pool,
)
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.config.distillation)
Expand Down
Loading