[trainer,cfg,rollout,algo] feat: (MOPD, 1/2) Multi-Teacher Model, Server Managers and Config #5834
JacobHelwig wants to merge 22 commits into verl-project:main
Conversation
Code Review
This pull request introduces a MultiTeacherModelManager to support multiple teacher models during distillation, refactoring the existing TeacherModelManager and AsyncTeacherLLMServerManager to accommodate this architecture. Key changes include moving configuration and tokenizer management into the server manager and introducing helper functions for async task execution. Review feedback identifies a potential issue with mixing auto_await and asyncio.run in compute_logprobs, incorrect type hints for coroutines in _run_all, and performance risks associated with redundant tokenizer loading in the AsyncTeacherLLMServerManager constructor.
```diff
 def compute_logprobs(self, data):
     self.wake_up()
     try:
-        return self._run_single(self.server_manager.compute_teacher_logprobs_batch(data))
+        return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
     finally:
         self.sleep()
```
The compute_logprobs method is synchronous, but it calls self.wake_up() and self.sleep(), which are decorated with @auto_await. When called synchronously, @auto_await typically runs the coroutine to completion using the current event loop or by creating a new one. However, _run_single explicitly uses asyncio.run(), which creates a new event loop and will fail if an event loop is already running in the same thread (which might happen if auto_await left a loop active).
More importantly, mixing auto_await for some calls and asyncio.run (via _run_single) for others in the same synchronous method is fragile and inefficient. It is better to wrap the entire sequence of operations into a single async method and run that once.
Suggested change:

```diff
 def compute_logprobs(self, data):
-    self.wake_up()
-    try:
-        return _run_single(self.server_manager.compute_teacher_logprobs_batch(data))
-    finally:
-        self.sleep()
+    async def _compute():
+        await self.wake_up()
+        try:
+            return await self.server_manager.compute_teacher_logprobs_batch(data)
+        finally:
+            await self.sleep()
+    return _run_single(_compute())
```
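The reviewer's suggestion can be sketched end-to-end. The class below is a hypothetical stand-in (method names mirror the PR, bodies are illustrative) showing that wrapping wake-up, compute, and sleep into one coroutine lets a synchronous caller enter the event loop exactly once:

```python
import asyncio


class TeacherManagerSketch:
    """Hypothetical stand-in for the real manager; bodies are illustrative."""

    def __init__(self):
        self.calls = []

    async def wake_up(self):
        self.calls.append("wake_up")

    async def compute_teacher_logprobs_batch(self, data):
        self.calls.append("compute")
        return [len(x) for x in data]

    async def sleep(self):
        self.calls.append("sleep")

    def compute_logprobs(self, data):
        # One coroutine wraps the whole wake -> compute -> sleep sequence,
        # so asyncio.run is entered exactly once from the sync caller.
        async def _compute():
            await self.wake_up()
            try:
                return await self.compute_teacher_logprobs_batch(data)
            finally:
                await self.sleep()

        return asyncio.run(_compute())


mgr = TeacherManagerSketch()
result = mgr.compute_logprobs(["ab", "cde"])
print(result)     # [2, 3]
print(mgr.calls)  # ['wake_up', 'compute', 'sleep']
```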
```python
async def _run_all(tasks: list[asyncio.Task]):
    await asyncio.gather(*tasks)
```
The type hint for tasks in _run_all is list[asyncio.Task], but the function is called with a list of coroutines (e.g., in TeacherModelManager.wake_up). While asyncio.gather accepts coroutines, the type hint is technically incorrect and might be misleading for static analysis tools.
Suggested change:

```diff
-async def _run_all(tasks: list[asyncio.Task]):
+async def _run_all(tasks: list):
     await asyncio.gather(*tasks)
```
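As a sketch of the typing point: `asyncio.gather` accepts any awaitable, so `collections.abc.Awaitable` is arguably a more precise element hint than either `asyncio.Task` or the suggested bare `list` (illustrative only, not the PR's code):

```python
import asyncio
from collections.abc import Awaitable


async def _run_all(tasks: list[Awaitable]) -> list:
    # gather accepts coroutine objects and Task instances alike.
    return await asyncio.gather(*tasks)


async def double(x: int) -> int:
    return 2 * x


async def main() -> list:
    # A bare coroutine and a wrapped Task both satisfy the Awaitable hint.
    return await _run_all([double(1), asyncio.ensure_future(double(2))])


print(asyncio.run(main()))  # [2, 4]
```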
```python
model_config = HFModelConfig(path=teacher_model_config.model_path)
text_tokenizer = model_config.tokenizer
if model_config.tokenizer is None:
    raise ValueError(f"Tokenizer is required for teacher model {teacher_model_config.model_path}")
self.pad_token_id = text_tokenizer.pad_token_id
```
Initializing HFModelConfig and accessing model_config.tokenizer inside __init__ may trigger redundant and expensive tokenizer loading every time an AsyncTeacherLLMServerManager is instantiated. Since this manager is initialized in the AgentLoopWorker, which can be numerous, this could lead to significant overhead and memory pressure. It is recommended to pass the pad_token_id directly or ensure the tokenizer is cached.
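One way to avoid the repeated loads is to memoize by model path. This is a generic sketch with a counting stub standing in for the real HF tokenizer load; `expensive_tokenizer_load` and the cache helper are hypothetical names, not verl APIs:

```python
from functools import lru_cache

LOAD_COUNT = 0


def expensive_tokenizer_load(model_path: str) -> dict:
    """Illustrative stand-in for the real (slow) tokenizer load."""
    global LOAD_COUNT
    LOAD_COUNT += 1
    return {"path": model_path, "pad_token_id": 0}


@lru_cache(maxsize=None)
def get_cached_tokenizer(model_path: str) -> dict:
    # Each distinct path is loaded once; later instantiations of the
    # server manager would reuse the cached result.
    return expensive_tokenizer_load(model_path)


for _ in range(3):
    tok = get_cached_tokenizer("teacher/model-a")
print(LOAD_COUNT)  # 1
```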
Force-pushed from 12f4b5b to 15e2efe.
We should merge this PR with the config change in #5774.

/gemini review
Code Review
This pull request refactors the distillation framework to support multi-teacher configurations. Key changes include the introduction of a MultiTeacherModelManager to coordinate multiple teacher instances, updates to the DistillationConfig and DistillationTeacherModelConfig schemas to support a dictionary of teacher models, and the addition of validation logic for teacher context lengths and inference engine parameters. Review feedback highlighted a concurrency risk when using asyncio.run() in an existing event loop, an indentation error in the teacher model resolution logic that would break multi-teacher support, and an initialization issue where the manager might use an unconfigured singular teacher field instead of the resolved configuration.
```python
def _run_single(task):
    async def run():
        return await task

    return asyncio.run(run())
```
Using asyncio.run() inside _run_single is dangerous because it raises a RuntimeError if an event loop is already running in the current thread. In the context of the Ray driver or actors where verl typically runs, an event loop is often already present. It is safer to use the @auto_await decorator, which handles both sync and async calling contexts correctly.
References
- Avoid using asyncio.run() in environments where an event loop might already be running; use safer alternatives like run_until_complete or auto-awaiting wrappers.
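The failure mode is easy to reproduce: `asyncio.run()` checks for a running loop in the current thread and raises rather than nesting. A minimal self-contained demonstration:

```python
import asyncio


async def nested_run():
    # asyncio.run refuses to start a nested event loop in a thread that
    # is already running one, so it raises RuntimeError here.
    coro = asyncio.sleep(0)
    try:
        asyncio.run(coro)
    except RuntimeError as err:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return str(err)
    return "no error"


msg = asyncio.run(nested_run())
print(msg)
```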
verl/workers/config/distillation.py (outdated)
```python
for model_name, teacher_model in self.teacher_models.items():
    teacher_model = omega_conf_to_dataclass(teacher_model, dataclass_type=DistillationTeacherModelConfig)
    if teacher_model.is_configured(is_multi=True):
        teacher_models[model_name] = teacher_model
    return teacher_models
```
The return teacher_models statement is incorrectly indented inside the for loop. This causes the method to return after processing only the first teacher model entry, which breaks support for multiple teacher models (even if currently limited to one in the runtime path, the configuration resolution logic is incorrect).
References
- Ensure logic correctly handles all items in a collection before returning, unless an early exit is explicitly intended.
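The bug pattern is worth seeing in isolation. The two toy resolvers below (illustrative names, not the PR's code) differ only in the indentation of the `return`:

```python
def resolve_buggy(raw: dict) -> dict:
    resolved = {}
    for name, cfg in raw.items():
        resolved[name] = cfg
        return resolved  # BUG: exits after processing only the first entry


def resolve_fixed(raw: dict) -> dict:
    resolved = {}
    for name, cfg in raw.items():
        resolved[name] = cfg
    return resolved  # return dedented out of the loop: all entries kept


teachers = {"qwen": {}, "llama": {}}
print(len(resolve_buggy(teachers)))  # 1
print(len(resolve_fixed(teachers)))  # 2
```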
```python
"""TODO: MOPD -- split resource pool across teachers and init one TeacherModelManager per teacher."""
self.teacher_model_manager = TeacherModelManager(
    distillation_config=self.distillation_config,
    teacher_model_config=self.distillation_config.teacher_model,
```
The MultiTeacherModelManager is passing self.distillation_config.teacher_model directly to the TeacherModelManager. However, if the user has configured the teacher(s) using the new teacher_models (plural) dictionary, the singular teacher_model field will be empty (default), leading to initialization failures. You should use self.distillation_config.get_single_teacher_model() to retrieve the correctly resolved teacher configuration.
References
- Always use resolved configuration objects or helper methods that account for multiple configuration sources (e.g., singular vs plural fields).
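A sketch of the resolution helper the review asks for: a `get_single_teacher_model` that prefers the plural `teacher_models` dict and falls back to the singular field. Field and method names mirror the PR, but the bodies here are assumptions:

```python
from dataclasses import dataclass, field


@dataclass
class DistillationConfigSketch:
    """Hypothetical sketch of the config; not the PR's actual dataclass."""

    teacher_model: dict = field(default_factory=dict)
    teacher_models: dict = field(default_factory=dict)

    def get_single_teacher_model(self) -> dict:
        # Prefer the plural dict; the singular field is the fallback.
        if self.teacher_models:
            if len(self.teacher_models) != 1:
                raise ValueError("exactly one teacher is supported for now")
            return next(iter(self.teacher_models.values()))
        return self.teacher_model


cfg = DistillationConfigSketch(teacher_models={"default": {"model_path": "teacher-a"}})
print(cfg.get_single_teacher_model())  # {'model_path': 'teacher-a'}
```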
```yaml
# multi-teacher configs
teacher_models: {}

# single-teacher config
```
The single-teacher config should be absorbed into teacher_models, since a single teacher is a special case of MOPD.
We could place the single-teacher config under teacher_models with a special key (e.g., `default`)?
What does this PR do?
Classes and configs for managing multiple sets of teacher models and servers for multi-teacher OPD. This PR only supports a single teacher.
Testing
Tests to demonstrate no regression in single teacher OPD.
Script (collapsed in the PR); plots: GSM8k eval acc, GSM8k train acc, GSM8k distillation loss.
Design & Code Changes
- `MultiTeacherModelManager`: new class that manages multiple `TeacherModelManager`s for MOPD.
- `AsyncTeacherLLMServerManager`: add an attribute `server_manager: AsyncLLMServerManager`. Previously, `AsyncTeacherLLMServerManager` inherited from `AsyncLLMServerManager`, but for MOPD it will manage one `AsyncLLMServerManager` per teacher.
- `teacher_model` entry for single teacher OPD.

For the multi-teacher training script, teacher model args will be specified as:
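The example itself was cut off in the capture above. Purely as a hypothetical sketch (the key and field names are assumptions, not the PR's actual schema), a `teacher_models` entry might look like:

```yaml
# Hypothetical sketch only; not the PR's actual schema.
distillation:
  teacher_models:
    teacher_a:
      model_path: /path/to/teacher-a
    teacher_b:
      model_path: /path/to/teacher-b
```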