-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[trainer,rollout,algo] feat: (MOPD, 1/3) Multi-Teacher Model and Server Managers #5834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
504581f
5634e8c
15e2efe
c617ab6
2a0d681
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,33 +28,45 @@ | |||||||||||||||||||||||||||||||||
| logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @auto_await | ||||||||||||||||||||||||||||||||||
| async def _run_all(tasks: list[asyncio.Task]): | ||||||||||||||||||||||||||||||||||
| await asyncio.gather(*tasks) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+32
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type hint for
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _run_single(task): | ||||||||||||||||||||||||||||||||||
| async def run(): | ||||||||||||||||||||||||||||||||||
| return await task | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return asyncio.run(run()) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class TeacherModelManager: | ||||||||||||||||||||||||||||||||||
| """Teacher model manager.""" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| config: DictConfig, | ||||||||||||||||||||||||||||||||||
| resource_pool: RayResourcePool = None, | ||||||||||||||||||||||||||||||||||
| teacher_model_config: DistillationTeacherModelConfig, | ||||||||||||||||||||||||||||||||||
| resource_pool: RayResourcePool, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| Initialize the teacher model manager. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||
| config (DictConfig): Teacher model configuration. | ||||||||||||||||||||||||||||||||||
| resource_pool (RayResourcePool, optional): Resource pool. Defaults to None. | ||||||||||||||||||||||||||||||||||
| teacher_model_config (DistillationTeacherModelConfig): Teacher model configuration. | ||||||||||||||||||||||||||||||||||
| resource_pool (RayResourcePool): Resource pool. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Need dataclass conversion for max_logprobs handling in post_init | ||||||||||||||||||||||||||||||||||
| self.config: DistillationConfig = omega_conf_to_dataclass(config) | ||||||||||||||||||||||||||||||||||
| self.teacher_model_config = teacher_model_config | ||||||||||||||||||||||||||||||||||
| self.resource_pool = resource_pool | ||||||||||||||||||||||||||||||||||
| self._initialize_llm_servers() | ||||||||||||||||||||||||||||||||||
| self._initialize_async_server_manager() | ||||||||||||||||||||||||||||||||||
| self._initialize_load_balancer_handle() | ||||||||||||||||||||||||||||||||||
| self._initialize_router() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.sleep() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _initialize_llm_servers(self): | ||||||||||||||||||||||||||||||||||
| teacher_model_config: DistillationTeacherModelConfig = self.config.teacher_model | ||||||||||||||||||||||||||||||||||
| teacher_model_config = self.teacher_model_config | ||||||||||||||||||||||||||||||||||
| teacher_world_size = ( | ||||||||||||||||||||||||||||||||||
| teacher_model_config.inference.tensor_model_parallel_size | ||||||||||||||||||||||||||||||||||
| * teacher_model_config.inference.data_parallel_size | ||||||||||||||||||||||||||||||||||
|
|
@@ -71,10 +83,6 @@ def _initialize_llm_servers(self): | |||||||||||||||||||||||||||||||||
| rollout_config = teacher_model_config.inference | ||||||||||||||||||||||||||||||||||
| model_config = HFModelConfig(path=teacher_model_config.model_path) | ||||||||||||||||||||||||||||||||||
| self.tokenizer = model_config.get_processor() | ||||||||||||||||||||||||||||||||||
| text_tokenizer = model_config.tokenizer | ||||||||||||||||||||||||||||||||||
| if model_config.tokenizer is None: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}") | ||||||||||||||||||||||||||||||||||
| self.pad_token_id = text_tokenizer.pad_token_id | ||||||||||||||||||||||||||||||||||
| self.rollout_replicas = [ | ||||||||||||||||||||||||||||||||||
| rollout_replica_class( | ||||||||||||||||||||||||||||||||||
| replica_rank=replica_rank, | ||||||||||||||||||||||||||||||||||
|
|
@@ -88,31 +96,23 @@ def _initialize_llm_servers(self): | |||||||||||||||||||||||||||||||||
| if self.resource_pool: | ||||||||||||||||||||||||||||||||||
| split_resource_pools = split_resource_pool(self.resource_pool, split_size=teacher_world_size) | ||||||||||||||||||||||||||||||||||
| assert len(split_resource_pools) == len(self.rollout_replicas) | ||||||||||||||||||||||||||||||||||
| self._run_all( | ||||||||||||||||||||||||||||||||||
| _run_all( | ||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||
| server.init_colocated(resource_pool) | ||||||||||||||||||||||||||||||||||
| for server, resource_pool in zip(self.rollout_replicas, split_resource_pools, strict=True) | ||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| self._run_all([server.init_standalone() for server in self.rollout_replicas]) | ||||||||||||||||||||||||||||||||||
| _run_all([server.init_standalone() for server in self.rollout_replicas]) | ||||||||||||||||||||||||||||||||||
| self.server_handles = [server._server_handle for server in self.rollout_replicas] | ||||||||||||||||||||||||||||||||||
| self.server_addresses = [server._server_address for server in self.rollout_replicas] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _initialize_async_server_manager(self): | ||||||||||||||||||||||||||||||||||
| def _initialize_load_balancer_handle(self): | ||||||||||||||||||||||||||||||||||
| from verl.experimental.agent_loop.agent_loop import GlobalRequestLoadBalancer | ||||||||||||||||||||||||||||||||||
| from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.load_balancer_handle = GlobalRequestLoadBalancer.remote( | ||||||||||||||||||||||||||||||||||
| server_actor_ids=self.server_addresses, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| self.server_manager = AsyncTeacherLLMServerManager( | ||||||||||||||||||||||||||||||||||
| config=self.config, | ||||||||||||||||||||||||||||||||||
| servers=list(zip(self.server_addresses, self.server_handles, strict=True)), | ||||||||||||||||||||||||||||||||||
| load_balancer_handle=self.load_balancer_handle, | ||||||||||||||||||||||||||||||||||
| distillation_config=self.config, | ||||||||||||||||||||||||||||||||||
| pad_token_id=self.pad_token_id, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _initialize_router(self): | ||||||||||||||||||||||||||||||||||
| worker_urls = [f"http://{server_address}" for server_address in self.server_addresses] | ||||||||||||||||||||||||||||||||||
|
|
@@ -124,29 +124,86 @@ def _initialize_router(self): | |||||||||||||||||||||||||||||||||
| def get_router_address(self): | ||||||||||||||||||||||||||||||||||
| return self.router_address | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @auto_await | ||||||||||||||||||||||||||||||||||
| async def wake_up(self): | ||||||||||||||||||||||||||||||||||
| """Wake up all rollout replica instances.""" | ||||||||||||||||||||||||||||||||||
| await _run_all([replica.wake_up() for replica in self.rollout_replicas]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @auto_await | ||||||||||||||||||||||||||||||||||
| async def sleep(self): | ||||||||||||||||||||||||||||||||||
| """Sleep all rollout replica instances.""" | ||||||||||||||||||||||||||||||||||
| await _run_all([replica.sleep() for replica in self.rollout_replicas]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class MultiTeacherModelManager: | ||||||||||||||||||||||||||||||||||
| """Multi Teacher model manager.""" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| config: DictConfig, | ||||||||||||||||||||||||||||||||||
| resource_pool: RayResourcePool, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| Initialize the teacher model manager. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||
| config (DictConfig): Full configuration. | ||||||||||||||||||||||||||||||||||
| resource_pool (RayResourcePool): Resource pool. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Need dataclass conversion for max_logprobs handling in post_init | ||||||||||||||||||||||||||||||||||
| self.config = config | ||||||||||||||||||||||||||||||||||
| self.distillation_config: DistillationConfig = omega_conf_to_dataclass(config.distillation) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.resource_pool = resource_pool | ||||||||||||||||||||||||||||||||||
| self._initialize_teacher_model_managers() | ||||||||||||||||||||||||||||||||||
| self._initialize_async_server_manager() | ||||||||||||||||||||||||||||||||||
| self.sleep() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _initialize_teacher_model_managers(self): | ||||||||||||||||||||||||||||||||||
| """TODO: MOPD -- split resource pool across teachers and init one TeacherModelManager per teacher.""" | ||||||||||||||||||||||||||||||||||
| self.teacher_model_manager = TeacherModelManager( | ||||||||||||||||||||||||||||||||||
| teacher_model_config=self.distillation_config.teacher_model, resource_pool=self.resource_pool | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| self.server_addresses = self.teacher_model_manager.server_addresses | ||||||||||||||||||||||||||||||||||
| self.server_handles = self.teacher_model_manager.server_handles | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _initialize_async_server_manager(self): | ||||||||||||||||||||||||||||||||||
| """TODO: MOPD -- balancers/servers become a dict (one per teacher)""" | ||||||||||||||||||||||||||||||||||
| from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.load_balancer_handle = self.teacher_model_manager.load_balancer_handle | ||||||||||||||||||||||||||||||||||
| servers = list(zip(self.server_addresses, self.server_handles, strict=True)) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # In standalone mode, server manager is initialized in the agent loop | ||||||||||||||||||||||||||||||||||
| if not self.distillation_config.teacher_model.enable_resource_pool: | ||||||||||||||||||||||||||||||||||
| self.server_manager = AsyncTeacherLLMServerManager( | ||||||||||||||||||||||||||||||||||
| config=self.config, | ||||||||||||||||||||||||||||||||||
| servers=servers, | ||||||||||||||||||||||||||||||||||
| load_balancer_handle=self.load_balancer_handle, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def get_router_address(self): | ||||||||||||||||||||||||||||||||||
| """TODO: MOPD -- return dict of router addresses (one per teacher)""" | ||||||||||||||||||||||||||||||||||
| return self.teacher_model_manager.router_address | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def compute_logprobs(self, data): | ||||||||||||||||||||||||||||||||||
| self.wake_up() | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data)) | ||||||||||||||||||||||||||||||||||
| return _run_single(self.server_manager.compute_teacher_logprobs_batch(data)) | ||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||
| self.sleep() | ||||||||||||||||||||||||||||||||||
|
Comment on lines
190
to
195
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The More importantly, mixing
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @auto_await | ||||||||||||||||||||||||||||||||||
| async def wake_up(self): | ||||||||||||||||||||||||||||||||||
| """Wake up all rollout replica instances.""" | ||||||||||||||||||||||||||||||||||
| await self._run_all([replica.wake_up() for replica in self.rollout_replicas]) | ||||||||||||||||||||||||||||||||||
| """Wake up all rollout replica instances. | ||||||||||||||||||||||||||||||||||
| TODO: MOPD -- wake up each of the teacher model managers. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| await _run_all([manager.wake_up() for manager in [self.teacher_model_manager]]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @auto_await | ||||||||||||||||||||||||||||||||||
| async def sleep(self): | ||||||||||||||||||||||||||||||||||
| """Sleep all rollout replica instances.""" | ||||||||||||||||||||||||||||||||||
| await self._run_all([replica.sleep() for replica in self.rollout_replicas]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @auto_await | ||||||||||||||||||||||||||||||||||
| async def _run_all(self, tasks: list[asyncio.Task]): | ||||||||||||||||||||||||||||||||||
| await asyncio.gather(*tasks) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _run_single(self, task): | ||||||||||||||||||||||||||||||||||
| async def run(): | ||||||||||||||||||||||||||||||||||
| return await task | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return asyncio.run(run()) | ||||||||||||||||||||||||||||||||||
| """Sleep all rollout replica instances. | ||||||||||||||||||||||||||||||||||
| TODO: MOPD -- sleep each of the teacher model managers. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| await _run_all([manager.sleep() for manager in [self.teacher_model_manager]]) | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initializing
HFModelConfigand accessingmodel_config.tokenizerinside__init__may trigger redundant and expensive tokenizer loading every time anAsyncTeacherLLMServerManageris instantiated. Since this manager is initialized in theAgentLoopWorker, which can be numerous, this could lead to significant overhead and memory pressure. It is recommended to pass thepad_token_iddirectly or ensure the tokenizer is cached.