[trainer,rollout,algo] feat: (MOPD, 1/3) Multi-Teacher Model and Server Managers #5834
JacobHelwig wants to merge 5 commits into verl-project:main
Conversation
Code Review
This pull request introduces a MultiTeacherModelManager to support multiple teacher models during distillation, refactoring the existing TeacherModelManager and AsyncTeacherLLMServerManager to accommodate this architecture. Key changes include moving configuration and tokenizer management into the server manager and introducing helper functions for async task execution. Review feedback identifies a potential issue with mixing auto_await and asyncio.run in compute_logprobs, incorrect type hints for coroutines in _run_all, and performance risks associated with redundant tokenizer loading in the AsyncTeacherLLMServerManager constructor.
```diff
 def compute_logprobs(self, data):
     self.wake_up()
     try:
-        return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data))
+        return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
     finally:
         self.sleep()
```
The compute_logprobs method is synchronous, but it calls self.wake_up() and self.sleep(), which are decorated with @auto_await. When called synchronously, @auto_await typically runs the coroutine to completion using the current event loop or by creating a new one. However, _run_single explicitly uses asyncio.run(), which creates a new event loop and will fail if an event loop is already running in the same thread (which might happen if auto_await left a loop active).
More importantly, mixing auto_await for some calls and asyncio.run (via _run_single) for others in the same synchronous method is fragile and inefficient. It is better to wrap the entire sequence of operations into a single async method and run that once.
Suggested change:

```diff
 def compute_logprobs(self, data):
-    self.wake_up()
-    try:
-        return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
-    finally:
-        self.sleep()
+    async def _compute():
+        await self.wake_up()
+        try:
+            return await self.server_manager.compute_teacher_logprobs_batch(data)
+        finally:
+            await self.sleep()
+    return _run_single(_compute())
```
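The suggested pattern can be sketched as a self-contained example. This is an illustrative mock, not the verl implementation: the `Manager` class and its `compute_teacher_logprobs_batch` body are stand-ins, and `_run_single` is assumed to be a thin wrapper over `asyncio.run` as described in the review.

```python
import asyncio


def _run_single(coro):
    # Enter the event loop exactly once for the whole operation.
    return asyncio.run(coro)


class Manager:
    # Stand-ins for the real wake_up/sleep/compute calls in the PR.
    async def wake_up(self):
        self.awake = True

    async def sleep(self):
        self.awake = False

    async def compute_teacher_logprobs_batch(self, data):
        return [x * 2 for x in data]  # placeholder for the real server call

    def compute_logprobs(self, data):
        # Wrap the whole wake_up -> compute -> sleep sequence in one
        # coroutine, so there is no mix of @auto_await-driven calls and a
        # separate asyncio.run for the compute step.
        async def _compute():
            await self.wake_up()
            try:
                return await self.compute_teacher_logprobs_batch(data)
            finally:
                await self.sleep()

        return _run_single(_compute())
```

With this shape, `sleep()` still runs even if the compute step raises, and only one event loop is ever created per call.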
```python
async def _run_all(tasks: list[asyncio.Task]):
    await asyncio.gather(*tasks)
```
The type hint for tasks in _run_all is list[asyncio.Task], but the function is called with a list of coroutines (e.g., in TeacherModelManager.wake_up). While asyncio.gather accepts coroutines, the type hint is technically incorrect and might be misleading for static analysis tools.
Suggested change:

```diff
-async def _run_all(tasks: list[asyncio.Task]):
+async def _run_all(tasks: list):
     await asyncio.gather(*tasks)
```
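A bare `list` satisfies the checker but drops all type information. A sketch of a more precise alternative, assuming the callers pass coroutines or tasks interchangeably, is to annotate with `Awaitable`, which covers both:

```python
import asyncio
from collections.abc import Awaitable


async def _run_all(tasks: list[Awaitable]) -> list:
    # Awaitable covers coroutines as well as asyncio.Task objects,
    # matching everything asyncio.gather actually accepts here.
    return list(await asyncio.gather(*tasks))


async def _double(x: int) -> int:
    return x * 2


# Usage: pass plain coroutines; gather preserves input order.
results = asyncio.run(_run_all([_double(1), _double(2)]))
```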
```python
model_config = HFModelConfig(path=teacher_model_config.model_path)
text_tokenizer = model_config.tokenizer
if model_config.tokenizer is None:
    raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}")
self.pad_token_id = text_tokenizer.pad_token_id
```
Initializing HFModelConfig and accessing model_config.tokenizer inside __init__ may trigger redundant and expensive tokenizer loading every time an AsyncTeacherLLMServerManager is instantiated. Since this manager is initialized in the AgentLoopWorker, which can be numerous, this could lead to significant overhead and memory pressure. It is recommended to pass the pad_token_id directly or ensure the tokenizer is cached.
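One way to realize the caching suggestion is a memoized module-level helper, so many workers constructing managers for the same model path trigger one tokenizer load. This is a hypothetical sketch: `load_tokenizer`, `FakeTokenizer`, and `_load_pad_token_id` are illustrative names, not verl APIs, and the counter exists only to make the caching observable.

```python
from functools import lru_cache

# Counts invocations of the (mock) expensive tokenizer load.
LOAD_CALLS = {"n": 0}


class FakeTokenizer:
    pad_token_id = 0


def load_tokenizer(model_path: str):
    # Stand-in for the expensive HF tokenizer load.
    LOAD_CALLS["n"] += 1
    return FakeTokenizer()


@lru_cache(maxsize=None)
def _load_pad_token_id(model_path: str) -> int:
    # Load the tokenizer once per model path; later constructors
    # reuse the cached pad_token_id instead of reloading.
    tokenizer = load_tokenizer(model_path)
    if tokenizer is None:
        raise ValueError(f"Tokenizer is required for teacher model {model_path}")
    return tokenizer.pad_token_id


# Four manager constructions for the same teacher -> one tokenizer load.
ids = [_load_pad_token_id("teacher-a") for _ in range(4)]
```

The other option the review mentions, passing `pad_token_id` directly into the constructor, avoids the load entirely at the cost of plumbing one more argument through.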
What does this PR do?
Classes for managing multiple sets of teacher models and servers for multi-teacher OPD. This PR only supports a single teacher.
Testing
Tests to demonstrate no regression in single-teacher OPD:
- Script
- GSM8k eval acc
- GSM8k train acc
- GSM8k distillation loss
Design & Code Changes
- `MultiTeacherModelManager`: new class that manages multiple `TeacherModelManager` instances for MOPD.
- `AsyncTeacherLLMServerManager`: add an attribute `server_manager: AsyncLLMServerManager`. Previously, `AsyncTeacherLLMServerManager` inherited from `AsyncLLMServerManager`, but for MOPD, it will manage one `AsyncLLMServerManager` per teacher.
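The composition-over-inheritance change above can be sketched minimally. All class bodies here are mocks of the verl classes named in the PR, and `compute_all_logprobs` is a hypothetical fan-out method added for illustration:

```python
import asyncio


class AsyncLLMServerManager:
    # Stand-in for verl's per-teacher server manager.
    async def compute_teacher_logprobs_batch(self, data):
        return [float(x) for x in data]  # placeholder for real logprobs


class AsyncTeacherLLMServerManager:
    # Composition, not inheritance: holds one AsyncLLMServerManager.
    def __init__(self):
        self.server_manager = AsyncLLMServerManager()


class MultiTeacherModelManager:
    # One teacher-level manager per teacher; fans a batch out to all.
    def __init__(self, num_teachers: int):
        self.teachers = [AsyncTeacherLLMServerManager() for _ in range(num_teachers)]

    async def compute_all_logprobs(self, data):
        # Query every teacher concurrently; results keep teacher order.
        return await asyncio.gather(
            *(t.server_manager.compute_teacher_logprobs_batch(data) for t in self.teachers)
        )
```

Holding the server manager as an attribute (rather than inheriting from it) is what lets the multi-teacher manager own several independent instances at once.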