diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 1457ac6fccc..fa747427ffa 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -32,7 +32,7 @@ from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config from verl.experimental.agent_loop.utils import resolve_config_path -from verl.experimental.teacher_loop import TeacherModelManager +from verl.experimental.teacher_loop import MultiTeacherModelManager from verl.protocol import DataProto from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup from verl.trainer.distillation import is_distillation_enabled @@ -454,8 +454,6 @@ def __init__( config, teacher_servers, load_balancer_handle=teacher_load_balancer_handle, - distillation_config=self.distillation_config, - pad_token_id=self.model_config.tokenizer.pad_token_id, ) else: self.teacher_server_manager = None @@ -999,7 +997,8 @@ class AgentLoopManager: config (DictConfig): whole config for main entrypoint. worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode. rollout_resource_pool (RayResourcePool): Resource pool for hybrid mode, only used by TensorRT-LLM. - teacher_model_manager (TeacherModelManager): Manager for streaming teacher computation, used for distillation. + teacher_model_manager (MultiTeacherModelManager): Manager for streaming teacher computation, used for + distillation. reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. """ @@ -1008,7 +1007,7 @@ def __init__( config: DictConfig, worker_group: RayWorkerGroup = None, rollout_resource_pool: RayResourcePool = None, - teacher_model_manager: TeacherModelManager = None, + teacher_model_manager: MultiTeacherModelManager = None, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): self.config = config @@ -1044,7 +1043,7 @@ async def create( worker_group: RayWorkerGroup = None, rollout_resource_pool: RayResourcePool = None, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, - teacher_model_manager: TeacherModelManager = None, + teacher_model_manager: MultiTeacherModelManager = None, ): """Create agent loop manager.""" instance = cls(config, worker_group, rollout_resource_pool, teacher_model_manager, reward_loop_worker_handles) diff --git a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py index f1df5cdf17a..32795028826 100644 --- a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py @@ -26,7 +26,7 @@ AsyncLLMServerManager, TokenOutput, ) -from verl.experimental.teacher_loop import TeacherModelManager +from verl.experimental.teacher_loop import MultiTeacherModelManager from verl.protocol import DataProto from verl.single_controller.ray import RayResourcePool, RayWorkerGroup from verl.utils.ray_utils import auto_await @@ -152,7 +152,7 @@ def __init__( config: DictConfig, worker_group: RayWorkerGroup = None, rollout_resource_pool: RayResourcePool = None, - teacher_model_manager: TeacherModelManager = None, + teacher_model_manager: MultiTeacherModelManager = None, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): self.agent_loop_workers_class = FullyAsyncAgentLoopWorker diff --git a/verl/experimental/teacher_loop/__init__.py b/verl/experimental/teacher_loop/__init__.py index 9e87dd88d90..16e6c29887f 100644 --- a/verl/experimental/teacher_loop/__init__.py +++ b/verl/experimental/teacher_loop/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .teacher_model import TeacherModelManager +from .teacher_model import MultiTeacherModelManager -__all__ = ["TeacherModelManager"] +__all__ = ["MultiTeacherModelManager"] diff --git a/verl/experimental/teacher_loop/teacher_manager.py b/verl/experimental/teacher_loop/teacher_manager.py index 28a89262e96..608d7b2b3c1 100644 --- a/verl/experimental/teacher_loop/teacher_manager.py +++ b/verl/experimental/teacher_loop/teacher_manager.py @@ -25,21 +25,26 @@ from verl.protocol import DataProto from verl.utils.config import omega_conf_to_dataclass from verl.utils.tokenizer import normalize_token_ids -from verl.workers.config import DistillationConfig, DistillationLossConfig +from verl.workers.config import ( + DistillationConfig, + DistillationLossConfig, + DistillationTeacherModelConfig, + HFModelConfig, +) def _get_teacher_sampling_params( - distillation_config: DistillationConfig, + teacher_config: DistillationTeacherModelConfig, distillation_loss_config: DistillationLossConfig, ) -> dict[str, Any]: """Get sampling parameters for teacher model when computing log probabilities for distillation.""" - if distillation_config.teacher_model.inference.temperature != 1.0: + if teacher_config.inference.temperature != 1.0: raise NotImplementedError("vLLM does not support temperature for prompt_logprobs.") num_logprobs = distillation_loss_config.topk if distillation_loss_config.loss_settings.use_topk else 0 return { "max_tokens": 1, - "temperature": distillation_config.teacher_model.inference.temperature, + "temperature": teacher_config.inference.temperature, "prompt_logprobs": num_logprobs, } @@ -85,24 +90,42 @@ def _unpad_teacher_inputs(data: DataProto) -> tuple[list[int], int, int]: return sequence_ids, valid_prompt_length, valid_response_length -class AsyncTeacherLLMServerManager(AsyncLLMServerManager): - """Teacher-specific async client used for distillation logprob computation.""" +class AsyncTeacherLLMServerManager: + """Teacher-specific async client used for distillation logprob computation. + TODO: MOPD -- servers and load_balance_handle become a dict (one per teacher) + """ def __init__( self, config: DictConfig, servers: list[tuple[str, ray.actor.ActorHandle]], load_balancer_handle: ray.actor.ActorHandle, - distillation_config: DictConfig | DistillationConfig, - pad_token_id: int, ): - super().__init__(config=config, servers=servers, load_balancer_handle=load_balancer_handle) - if isinstance(distillation_config, DistillationConfig): - self.distillation_config = distillation_config - else: - self.distillation_config: DistillationConfig = omega_conf_to_dataclass(distillation_config) + self.distillation_config: DistillationConfig = omega_conf_to_dataclass(config.distillation) self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss - self.pad_token_id = pad_token_id + self.teacher_config: DistillationTeacherModelConfig = self.distillation_config.teacher_model + + # Get pad token ID + teacher_model_config: DistillationTeacherModelConfig = self.distillation_config.teacher_model + model_config = HFModelConfig(path=teacher_model_config.model_path) + text_tokenizer = model_config.tokenizer + if model_config.tokenizer is None: + raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}") + self.pad_token_id = text_tokenizer.pad_token_id + self._initialize_teacher_server_managers( + config=config, servers=servers, load_balancer_handle=load_balancer_handle + ) + + def _initialize_teacher_server_managers( + self, + config: DictConfig, + servers: list[tuple[str, ray.actor.ActorHandle]], + load_balancer_handle: ray.actor.ActorHandle, + ): + # TODO: MOPD -- balancers/servers become a dict (one per teacher) + self.server_manager = AsyncLLMServerManager( + config=config, servers=servers, load_balancer_handle=load_balancer_handle + ) async def compute_teacher_logprobs_single( self, @@ -111,10 +134,11 @@ async def compute_teacher_logprobs_single( ) -> tuple[torch.Tensor, torch.Tensor]: """Compute teacher log probabilities for a single unpadded sequence.""" multi_modal_data = multi_modal_data or {} - teacher_output = await self.generate( + # TODO: MOPD -- select server manager from server manager dict based on example task key + teacher_output = await self.server_manager.generate( request_id=uuid4().hex, prompt_ids=sequence_ids, - sampling_params=_get_teacher_sampling_params(self.distillation_config, self.distillation_loss_config), + sampling_params=_get_teacher_sampling_params(self.teacher_config, self.distillation_loss_config), image_data=multi_modal_data.get("images"), video_data=multi_modal_data.get("videos"), ) diff --git a/verl/experimental/teacher_loop/teacher_model.py b/verl/experimental/teacher_loop/teacher_model.py index 966efa969ab..26119df6f76 100644 --- a/verl/experimental/teacher_loop/teacher_model.py +++ b/verl/experimental/teacher_loop/teacher_model.py @@ -28,33 +28,45 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +@auto_await +async def _run_all(tasks: list[asyncio.Task]): + await asyncio.gather(*tasks) + + +def _run_single(task): + async def run(): + return await task + + return asyncio.run(run()) + + class TeacherModelManager: """Teacher model manager.""" def __init__( self, - config: DictConfig, - resource_pool: RayResourcePool = None, + teacher_model_config: DistillationTeacherModelConfig, + resource_pool: RayResourcePool, ): """ Initialize the teacher model manager. Args: - config (DictConfig): Teacher model configuration. - resource_pool (RayResourcePool, optional): Resource pool. Defaults to None. + teacher_model_config (DistillationTeacherModelConfig): Teacher model configuration. + resource_pool (RayResourcePool): Resource pool. """ # Need dataclass conversion for max_logprobs handling in post_init - self.config: DistillationConfig = omega_conf_to_dataclass(config) + self.teacher_model_config = teacher_model_config self.resource_pool = resource_pool self._initialize_llm_servers() - self._initialize_async_server_manager() + self._initialize_load_balancer_handle() self._initialize_router() self.sleep() def _initialize_llm_servers(self): - teacher_model_config: DistillationTeacherModelConfig = self.config.teacher_model + teacher_model_config = self.teacher_model_config teacher_world_size = ( teacher_model_config.inference.tensor_model_parallel_size * teacher_model_config.inference.data_parallel_size @@ -71,10 +83,6 @@ def _initialize_llm_servers(self): rollout_config = teacher_model_config.inference model_config = HFModelConfig(path=teacher_model_config.model_path) self.tokenizer = model_config.get_processor() - text_tokenizer = model_config.tokenizer - if model_config.tokenizer is None: - raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}") - self.pad_token_id = text_tokenizer.pad_token_id self.rollout_replicas = [ rollout_replica_class( replica_rank=replica_rank, @@ -88,31 +96,23 @@ def _initialize_llm_servers(self): if self.resource_pool: split_resource_pools = split_resource_pool(self.resource_pool, split_size=teacher_world_size) assert len(split_resource_pools) == len(self.rollout_replicas) - self._run_all( + _run_all( [ server.init_colocated(resource_pool) for server, resource_pool in zip(self.rollout_replicas, split_resource_pools, strict=True) ] ) else: - self._run_all([server.init_standalone() for server in self.rollout_replicas]) + _run_all([server.init_standalone() for server in self.rollout_replicas]) self.server_handles = [server._server_handle for server in self.rollout_replicas] self.server_addresses = [server._server_address for server in self.rollout_replicas] - def _initialize_async_server_manager(self): + def _initialize_load_balancer_handle(self): from verl.experimental.agent_loop.agent_loop import GlobalRequestLoadBalancer - from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager self.load_balancer_handle = GlobalRequestLoadBalancer.remote( server_actor_ids=self.server_addresses, ) - self.server_manager = AsyncTeacherLLMServerManager( - config=self.config, - servers=list(zip(self.server_addresses, self.server_handles, strict=True)), - load_balancer_handle=self.load_balancer_handle, - distillation_config=self.config, - pad_token_id=self.pad_token_id, - ) def _initialize_router(self): worker_urls = [f"http://{server_address}" for server_address in self.server_addresses] @@ -124,29 +124,86 @@ def _initialize_router(self): def get_router_address(self): return self.router_address + @auto_await + async def wake_up(self): + """Wake up all rollout replica instances.""" + await _run_all([replica.wake_up() for replica in self.rollout_replicas]) + + @auto_await + async def sleep(self): + """Sleep all rollout replica instances.""" + await _run_all([replica.sleep() for replica in self.rollout_replicas]) + + +class MultiTeacherModelManager: + """Multi Teacher model manager.""" + + def __init__( + self, + config: DictConfig, + resource_pool: RayResourcePool, + ): + """ + Initialize the teacher model manager. + + Args: + config (DictConfig): Full configuration. + resource_pool (RayResourcePool): Resource pool. + """ + + # Need dataclass conversion for max_logprobs handling in post_init + self.config = config + self.distillation_config: DistillationConfig = omega_conf_to_dataclass(config.distillation) + + self.resource_pool = resource_pool + self._initialize_teacher_model_managers() + self._initialize_async_server_manager() + self.sleep() + + def _initialize_teacher_model_managers(self): + """TODO: MOPD -- split resource pool across teachers and init one TeacherModelManager per teacher.""" + self.teacher_model_manager = TeacherModelManager( + teacher_model_config=self.distillation_config.teacher_model, resource_pool=self.resource_pool + ) + self.server_addresses = self.teacher_model_manager.server_addresses + self.server_handles = self.teacher_model_manager.server_handles + + def _initialize_async_server_manager(self): + """TODO: MOPD -- balancers/servers become a dict (one per teacher)""" + from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager + + self.load_balancer_handle = self.teacher_model_manager.load_balancer_handle + servers = list(zip(self.server_addresses, self.server_handles, strict=True)) + + # In standalone mode, server manager is initialized in the agent loop + if not self.distillation_config.teacher_model.enable_resource_pool: + self.server_manager = AsyncTeacherLLMServerManager( + config=self.config, + servers=servers, + load_balancer_handle=self.load_balancer_handle, + ) + + def get_router_address(self): + """TODO: MOPD -- return dict of router addresses (one per teacher)""" + return self.teacher_model_manager.router_address + def compute_logprobs(self, data): self.wake_up() try: - return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data)) + return _run_single(self.server_manager.compute_teacher_logprobs_batch(data)) finally: self.sleep() @auto_await async def wake_up(self): - """Wake up all rollout replica instances.""" - await self._run_all([replica.wake_up() for replica in self.rollout_replicas]) + """Wake up all rollout replica instances. + TODO: MOPD -- wake up each of the teacher model managers. + """ + await _run_all([manager.wake_up() for manager in [self.teacher_model_manager]]) @auto_await async def sleep(self): - """Sleep all rollout replica instances.""" - await self._run_all([replica.sleep() for replica in self.rollout_replicas]) - - @auto_await - async def _run_all(self, tasks: list[asyncio.Task]): - await asyncio.gather(*tasks) - - def _run_single(self, task): - async def run(): - return await task - - return asyncio.run(run()) + """Sleep all rollout replica instances. + TODO: MOPD -- sleep each of the teacher model managers. + """ + await _run_all([manager.sleep() for manager in [self.teacher_model_manager]]) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 6cb58dc6251..7563458870e 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -845,11 +845,11 @@ def init_workers(self): # initialize teacher loop manager if self.use_teacher_policy: - from verl.experimental.teacher_loop import TeacherModelManager + from verl.experimental.teacher_loop import MultiTeacherModelManager teacher_resource_pool = self.resource_pool_manager.get_resource_pool(Role.TeacherModel) - self.teacher_model_manager = TeacherModelManager( - config=self.config.distillation, + self.teacher_model_manager = MultiTeacherModelManager( + config=self.config, resource_pool=teacher_resource_pool, ) self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.config.distillation)