diff --git a/.secrets.baseline b/.secrets.baseline index 7059d28bb..f93f324a6 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "README.md", "hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5", "is_verified": false, - "line_number": 454 + "line_number": 530 } ], "SLURM.md": [ @@ -561,5 +561,5 @@ } ] }, - "generated_at": "2026-03-02T22:46:56Z" + "generated_at": "2026-03-14T00:43:09Z" } diff --git a/README.md b/README.md index 3b533a9bc..4d12f592a 100644 --- a/README.md +++ b/README.md @@ -298,6 +298,82 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids! - Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.). - Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR. +### TeacherDistillationEnv follow-up + +The follow-up teacher environment uses a dedicated teacher server config and +attaches teacher prompt logprobs before the group is sent to the API. + +Teacher config shape: + +```python +TeacherDistillationConfig( + teacher_enabled=True, + teacher_top_k=8, +) +``` + +Teacher server configs are passed separately at init, just like the primary +`server_configs`: + +```python +env = MyTeacherEnv( + config=env_config, + server_configs=student_server_configs, + teacher_server_configs=[ + APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + api_key="", + server_type="vllm", + tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + ) + ], +) +``` + +You can either: + +- build a teacher-enabled env by mixing `TeacherDistillationEnv` into an existing + `BaseEnv`-derived env such as `GSM8kEnv`, or +- subclass `TeacherDistillationEnv` directly and implement the usual environment + methods yourself. 
+ +In both cases, `TeacherDistillationEnv` still assumes the normal `BaseEnv` +runtime contract: tokenized rollouts, `ScoredDataGroup` payloads, and the +standard `handle_send_to_api(...)` transport path. + +CLI shape: + +```bash +--env.teacher_enabled true \ +--teacher.base_url "http://localhost:9003/v1" \ +--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \ +--teacher.server_type vllm \ +--env.teacher_top_k 8 +``` + +If `--teacher.model_name` is a deployment alias rather than a tokenizer +identifier, also set `--teacher.tokenizer_name ...` so the env can validate +tokenizer compatibility. + +Scope note: + +- The teacher-aware CLI wiring currently exists for `serve`. +- If `teacher_enabled=True`, the generic `process` and `evaluate` commands will + fail loudly at env construction time unless you instantiate the env yourself + and pass `teacher_server_configs=...`. + +Tokenizer requirement: + +- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary. +- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion. + +Why same-tokenizer is required: + +- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer. +- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides. +- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes. 
+ --- ## Testing and Debugging Tools diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index fecc5828a..98f686828 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -199,18 +199,14 @@ def resolve_openai_configs( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e elif isinstance(default_server_configs, APIServerConfig): - # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." - ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e - server_configs = final_openai_config + server_configs = [final_openai_config] elif isinstance(default_server_configs, ServerBaseline): # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible logger.info("Using ServerBaseline configuration.") @@ -219,26 +215,22 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." 
- ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e if isinstance(default_server_configs, APIServerConfig): - server_configs = final_openai_config + server_configs = [final_openai_config] elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: logger.warning( f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - f"Proceeding with single OpenAI server configuration based on merged settings." + "Proceeding with single OpenAI server configuration." ) server_configs = [final_openai_config] - return server_configs diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 3c35bebb6..ba272e768 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -281,7 +281,7 @@ async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]: ), "Prompt or input_ids is required for get_logprobs!" top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1))) - top_k = max(1, top_k) + top_k = max(0, top_k) # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt from_prompt_text = False @@ -408,25 +408,22 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." 
- ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e if isinstance(default_server_configs, APIServerConfig): - server_configs = final_openai_config + server_configs = [final_openai_config] elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: logger.warning( f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - f"Proceeding with single OpenAI server configuration based on merged settings." + "Proceeding with single OpenAI server configuration." ) server_configs = [final_openai_config] diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py new file mode 100644 index 000000000..1f0e21109 --- /dev/null +++ b/atroposlib/envs/teacher_distillation_env.py @@ -0,0 +1,459 @@ +""" +Teacher distillation environment layer. + +This module adds teacher prompt-logprob fetching on top of BaseEnv without +modifying BaseEnv transport behavior. + +This implementation supports same-tokenizer distillation only. The teacher and +student must share the same tokenizer vocabulary so the student's token IDs can +be forwarded directly to the teacher and the returned teacher top-k token IDs +can be looked up directly in the student's logits. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC +from typing import Any, Dict, List, Optional, Tuple, Union + +import yaml +from pydantic import Field +from pydantic_cli import Cmd +from rich import print as rprint + +from ..utils.cli import ( + extract_namespace, + get_double_dash_flags, + get_prefixed_pydantic_model, + merge_dicts, +) +from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup +from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE +from .server_handling.openai_server import resolve_openai_configs +from .server_handling.server_baseline import APIServerConfig, ServerBaseline +from .server_handling.server_manager import ServerManager, ServerManagerConfig + +logger = logging.getLogger(__name__) + + +class TeacherDistillationConfig(BaseEnvConfig): + teacher_enabled: bool = Field( + default=False, + description="Whether to fetch teacher prompt logprobs for distillation.", + ) + teacher_top_k: int = Field( + default=0, + ge=-1, + description=( + "Number of extra prompt logprobs to fetch beyond the selected token. " + "Use 0 for selected-token-only prompt logprobs and <= -1 to disable " + "teacher fetching." + ), + ) + + +class TeacherDistillationEnv(BaseEnv, ABC): + """ + BaseEnv subclass that enriches scored groups with teacher distillation arrays. 
+ + Distillation payload shape: + - distill_token_ids: [sequence][position][k] (student vocab IDs) + - distill_logprobs: [sequence][position][k] + """ + + env_config_cls = TeacherDistillationConfig + teacher_namespace = "teacher" + + @classmethod + def teacher_config_init( + cls, + ) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]: + return None + + @classmethod + def _resolve_teacher_server_configs( + cls, + default_teacher_server_configs: Optional[ + Union[ServerBaseline, List[APIServerConfig], APIServerConfig] + ], + yaml_config: Dict[str, Any], + cli_passed_flags: Dict[str, Any], + ) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]: + teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" + teacher_cli_passed_args = extract_namespace( + cli_passed_flags, teacher_full_prefix + ) + yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {}) + + if ( + default_teacher_server_configs is None + and not teacher_cli_passed_args + and not yaml_teacher_config + ): + return None + + effective_teacher_server_configs = default_teacher_server_configs + if effective_teacher_server_configs is None: + effective_teacher_server_configs = APIServerConfig() + elif isinstance(effective_teacher_server_configs, ServerBaseline) and ( + teacher_cli_passed_args or yaml_teacher_config + ): + effective_teacher_server_configs = APIServerConfig( + **effective_teacher_server_configs.model_dump() + ) + + if ( + isinstance(effective_teacher_server_configs, list) + and len(effective_teacher_server_configs) == 1 + ): + default_teacher_config = effective_teacher_server_configs[0] + else: + default_teacher_config = effective_teacher_server_configs + + if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1: + yaml_teacher_config = yaml_teacher_config[0] + + if isinstance(default_teacher_config, APIServerConfig) and isinstance( + yaml_teacher_config, dict + ): + teacher_config_dict = merge_dicts( + 
default_teacher_config.model_dump(), + yaml_teacher_config, + teacher_cli_passed_args, + ) + else: + teacher_config_dict = {} + + teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config} + teacher_cli_wrapped = { + f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value + for key, value in teacher_cli_passed_args.items() + } + return resolve_openai_configs( + default_server_configs=effective_teacher_server_configs, + openai_config_dict=teacher_config_dict, + yaml_config=teacher_yaml_wrapped, + cli_passed_flags=teacher_cli_wrapped, + logger=logger, + ) + + @classmethod + def get_cli_serve_config_cls(cls) -> type: + default_env_config, default_server_configs = cls.config_init() + default_teacher_server_configs = cls.teacher_config_init() + + env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" + teacher_cli_base = get_prefixed_pydantic_model( + APIServerConfig, teacher_full_prefix + ) + + class CliServeConfig( + get_prefixed_pydantic_model(type(default_env_config), env_full_prefix), + get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix), + teacher_cli_base, + ServerManagerConfig, + Cmd, + ): + config: str | None = Field( + default=None, + description="Path to .yaml config file. 
CLI args override this.", + ) + + def run(self) -> None: + wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name" + if ( + getattr(self, wandb_name_attr, None) is None + and cls.name is not None + ): + setattr(self, wandb_name_attr, cls.name) + + if self.config is not None: + with open(self.config, "r") as f: + yaml_config = yaml.safe_load(f) + logger.info("Loaded config from %s", self.config) + else: + yaml_config = {} + + cli_passed_flags = get_double_dash_flags() + + env_config_dict = merge_dicts( + default_env_config.model_dump(), + yaml_config.get(ENV_NAMESPACE, {}), + extract_namespace(cli_passed_flags, env_full_prefix), + ) + + oai_cli_passed_args = extract_namespace( + cli_passed_flags, openai_full_prefix + ) + yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + + effective_server_configs = default_server_configs + if isinstance(effective_server_configs, ServerBaseline) and ( + oai_cli_passed_args or yaml_oai_config + ): + effective_server_configs = APIServerConfig( + **effective_server_configs.model_dump() + ) + + if ( + isinstance(effective_server_configs, list) + and len(effective_server_configs) == 1 + ): + default_openai_config_ = effective_server_configs[0] + else: + default_openai_config_ = effective_server_configs + + if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: + yaml_oai_config = yaml_oai_config[0] + + if isinstance(default_openai_config_, APIServerConfig) and isinstance( + yaml_oai_config, dict + ): + openai_config_dict = merge_dicts( + default_openai_config_.model_dump(), + yaml_oai_config, + oai_cli_passed_args, + ) + else: + openai_config_dict = {} + + server_manager_cli_passed_flags = {} + if "slurm" in cli_passed_flags: + server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"] + if "testing" in cli_passed_flags: + server_manager_cli_passed_flags["testing"] = cli_passed_flags[ + "testing" + ] + + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = 
yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + + server_manager_config_dict = merge_dicts( + ServerManagerConfig().model_dump(), + server_manager_yaml_dict, + server_manager_cli_passed_flags, + ) + + env_config = type(default_env_config)(**env_config_dict) + server_manager_config = ServerManagerConfig( + **server_manager_config_dict + ) + openai_configs = resolve_openai_configs( + default_server_configs=effective_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, + ) + teacher_configs = cls._resolve_teacher_server_configs( + default_teacher_server_configs=default_teacher_server_configs, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + ) + + env_kwargs = { + "config": env_config, + "server_configs": openai_configs, + "slurm": server_manager_config.slurm, + "testing": server_manager_config.testing, + } + if teacher_configs is not None: + env_kwargs["teacher_server_configs"] = teacher_configs + env = cls(**env_kwargs) + rprint(env_config) + rprint(openai_configs) + if teacher_configs is not None: + rprint(teacher_configs) + + try: + loop = asyncio.get_running_loop() + task = loop.create_task(env.env_manager()) + loop.run_until_complete(task) + except RuntimeError: + asyncio.run(env.env_manager()) + + return CliServeConfig + + def __init__( + self, + config: TeacherDistillationConfig, + server_configs: Union[ServerBaseline, List[APIServerConfig]], + teacher_server_configs: Optional[ + Union[ServerBaseline, APIServerConfig, List[APIServerConfig]] + ] = None, + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm=slurm, testing=testing) + self.teacher_server: Optional[ServerManager] = None + + if config.teacher_enabled: + if teacher_server_configs is None: + raise ValueError( + "teacher_enabled=True but no teacher server configuration was " + "provided. 
Pass teacher_server_configs=... when instantiating " + "the environment directly, or use the teacher-aware 'serve' CLI " + "path with --teacher.* flags. The generic BaseEnv 'process' and " + "'evaluate' commands do not currently wire teacher_server_configs." + ) + if isinstance(teacher_server_configs, APIServerConfig): + teacher_config_source = [teacher_server_configs] + else: + teacher_config_source = teacher_server_configs + self.teacher_server = ServerManager( + teacher_config_source, + slurm=False, + testing=False, + ) + if isinstance(teacher_config_source, list): + teacher_cfg = teacher_config_source[0] + else: + teacher_cfg = teacher_config_source + + teacher_tokenizer_name = ( + teacher_cfg.model_name + if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none") + else teacher_cfg.tokenizer_name + ) + self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name) + + # ------------------------------------------------------------------ + # Core fetch + # ------------------------------------------------------------------ + + def _validate_teacher_tokenizer_compatibility( + self, teacher_tokenizer_name: str + ) -> None: + student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" + if student_tok_name == teacher_tokenizer_name: + return + + try: + from transformers import AutoTokenizer + + teacher_tokenizer = AutoTokenizer.from_pretrained( + teacher_tokenizer_name, use_fast=True + ) + except Exception as exc: + raise ValueError( + "Cross-tokenizer distillation is not supported in this PR, and the " + f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to " + f"verify compatibility: {exc}" + ) from exc + + student_vocab = self.tokenizer.get_vocab() + teacher_vocab = teacher_tokenizer.get_vocab() + if student_vocab != teacher_vocab: + raise ValueError( + "Cross-tokenizer distillation is not supported in this PR. 
" + f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' " + f"and teacher tokenizer '{teacher_tokenizer_name}' do not match." + ) + + async def _fetch_teacher_for_sequence( + self, token_ids: List[int], top_k: int + ) -> Tuple[List[List[int]], List[List[float]]]: + assert self.teacher_server is not None + payload = await self.teacher_server.get_logprobs( + input_ids=token_ids, + top_k=top_k, + max_tokens=1, + split="train", + ) + return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"] + + # ------------------------------------------------------------------ + # Group enrichment + # ------------------------------------------------------------------ + + async def _attach_teacher_distillation( + self, group: ScoredDataGroup + ) -> ScoredDataGroup: + if not self.config.teacher_enabled or self.teacher_server is None: + return group + + seqs = group.get("tokens", []) + if not seqs: + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + + group_overrides = group.get("group_overrides") or {} + if group_overrides.get("skip_teacher_top_k", False): + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + + top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k)) + if top_k <= -1: + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + + tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs] + results = await asyncio.gather(*tasks, return_exceptions=True) + + distill_token_ids: List[List[List[int]]] = [] + distill_logprobs: List[List[List[float]]] = [] + for idx, result in enumerate(results): + if isinstance(result, Exception): + logger.warning( + "Teacher logprob fetch failed for seq %s: %s. 
" + "Dropping distill payload for this group.", + idx, + result, + ) + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + token_ids_k, logprobs_k = result + if len(token_ids_k) != len(logprobs_k): + logger.warning( + "Teacher prompt-topk length mismatch for seq %s (%s != %s). " + "Dropping distill payload for this group.", + idx, + len(token_ids_k), + len(logprobs_k), + ) + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + distill_token_ids.append(token_ids_k) + distill_logprobs.append(logprobs_k) + + group["distill_token_ids"] = distill_token_ids + group["distill_logprobs"] = distill_logprobs + return group + + async def handle_send_to_api( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Any = None, + do_send_to_api: bool = True, + abort_on_any_max_length_exceeded: bool = True, + ): + groups = scored_data if isinstance(scored_data, list) else [scored_data] + enriched_groups: List[ScoredDataGroup] = [] + for group in groups: + if group is None: + continue + enriched_groups.append(await self._attach_teacher_distillation(group)) + + payload: Union[ScoredDataGroup, List[ScoredDataGroup]] + if isinstance(scored_data, list): + payload = enriched_groups + else: + payload = enriched_groups[0] if enriched_groups else scored_data + + return await super().handle_send_to_api( + payload, + item=item, + do_send_to_api=do_send_to_api, + abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded, + ) diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 8cbd84ad7..7e545355e 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -1,13 +1,19 @@ """Tests for get_logprobs wrappers and server-manager routing.""" +import logging + import pytest +from atroposlib.envs.server_handling.openai_server import resolve_openai_configs from atroposlib.envs.server_handling.server_baseline import 
( APIServer, APIServerConfig, AsyncSemWithAdaptiveWeight, ) from atroposlib.envs.server_handling.server_manager import ServerManager +from atroposlib.envs.server_handling.vllm_server import ( + resolve_openai_configs as resolve_vllm_configs, +) class _FakeAPIServer(APIServer): @@ -103,3 +109,49 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server(): out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval") assert out_eval["server"] == "s1" assert s1.calls == 1 + + +def test_resolve_openai_configs_wraps_single_api_server_config_in_list(): + default_server_config = APIServerConfig( + model_name="test-model", + base_url="http://localhost:9001/v1", + api_key="x", + server_type="openai", + ) + merged_config = default_server_config.model_dump() + + server_configs = resolve_openai_configs( + default_server_configs=default_server_config, + openai_config_dict=merged_config, + yaml_config={}, + cli_passed_flags={}, + logger=logging.getLogger("test"), + ) + + assert isinstance(server_configs, list) + assert len(server_configs) == 1 + assert isinstance(server_configs[0], APIServerConfig) + assert server_configs[0].base_url == "http://localhost:9001/v1" + + +def test_resolve_vllm_configs_wraps_single_api_server_config_in_list(): + default_server_config = APIServerConfig( + model_name="test-model", + base_url="http://localhost:9001/v1", + api_key="x", + server_type="vllm", + ) + merged_config = default_server_config.model_dump() + + server_configs = resolve_vllm_configs( + default_server_configs=default_server_config, + openai_config_dict=merged_config, + yaml_config={}, + cli_passed_flags={}, + logger=logging.getLogger("test"), + ) + + assert isinstance(server_configs, list) + assert len(server_configs) == 1 + assert isinstance(server_configs[0], APIServerConfig) + assert server_configs[0].base_url == "http://localhost:9001/v1" diff --git a/atroposlib/tests/test_teacher_distillation_env.py 
b/atroposlib/tests/test_teacher_distillation_env.py new file mode 100644 index 000000000..c88252188 --- /dev/null +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -0,0 +1,344 @@ +"""Tests for TeacherDistillationEnv distillation enrichment.""" + +from types import SimpleNamespace + +import pytest + +from atroposlib.envs.server_handling.server_baseline import APIServerConfig +from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv + + +class _FakeTeacherServer: + def __init__(self, fail_on_call: int = -1): + self.calls = 0 + self.fail_on_call = fail_on_call + + async def get_logprobs(self, **kwargs): + self.calls += 1 + if self.calls == self.fail_on_call: + raise RuntimeError("teacher backend failure") + seq = kwargs["input_ids"] + return { + "prompt_tokens": seq, + "prompt_topk_token_ids": [[tok, tok + 1] for tok in seq], + "prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq], + } + + +class _ConcreteTeacherEnv(TeacherDistillationEnv): + async def get_next_item(self): + return None + + async def evaluate(self, *args, **kwargs): + return None + + +class _DummyTokenizer: + name_or_path = "student-model" + + def get_vocab(self): + return {"a": 1} + + +class _CapturingServerManager: + def __init__(self, configs, slurm=False, testing=False): + self.configs = configs + self.slurm = slurm + self.testing = testing + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_success(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": None, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + assert len(out["distill_token_ids"]) == 2 + assert len(out["distill_token_ids"][0]) == 3 + 
assert len(out["distill_logprobs"][1]) == 2 + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_failure_drops_payload(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2) + env.teacher_server = _FakeTeacherServer(fail_on_call=2) + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": None, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert out["distill_token_ids"] is None + assert out["distill_logprobs"] is None + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_negative_topk_skips_fetch(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3]], + "group_overrides": None, + "masks": [[-100, 2, 3]], + "scores": [1.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert env.teacher_server.calls == 0 + assert out["distill_token_ids"] is None + assert out["distill_logprobs"] is None + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_zero_topk_passthrough(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3]], + "group_overrides": None, + "masks": [[-100, 2, 3]], + "scores": [1.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert env.teacher_server.calls == 1 + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_group_override_topk_is_used(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0) + + seen_topks = [] + 
+ async def _fake_fetch(seq, top_k): + seen_topks.append(top_k) + return [[tok] for tok in seq], [[-0.1] for _ in seq] + + env.teacher_server = object() + env._fetch_teacher_for_sequence = _fake_fetch + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": {"teacher_top_k": 7}, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert seen_topks == [7, 7] + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_group_override_can_skip_fetch(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3]], + "group_overrides": {"skip_teacher_top_k": True}, + "masks": [[-100, 2, 3]], + "scores": [1.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert env.teacher_server.calls == 0 + assert out["distill_token_ids"] is None + assert out["distill_logprobs"] is None + + +def test_teacher_tokenizer_mismatch_raises(monkeypatch): + env = object.__new__(_ConcreteTeacherEnv) + + class _StudentTokenizer: + name_or_path = "student-model" + + def get_vocab(self): + return {"a": 1} + + class _TeacherTokenizer: + def get_vocab(self): + return {"b": 1} + + env.tokenizer = _StudentTokenizer() + monkeypatch.setattr( + "transformers.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: _TeacherTokenizer(), + ) + + with pytest.raises( + ValueError, match="Cross-tokenizer distillation is not supported" + ): + TeacherDistillationEnv._validate_teacher_tokenizer_compatibility( + env, + teacher_tokenizer_name="teacher-model", + ) + + +def test_init_requires_teacher_server_source(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + def _fake_base_init(self, config, 
server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + with pytest.raises( + ValueError, match="no teacher server configuration was provided" + ): + _ConcreteTeacherEnv( + config=config, + server_configs=[], + ) + + +def test_init_uses_explicit_teacher_server_configs(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + called = {} + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + def _fake_validate(self, teacher_tokenizer_name): + called["teacher_tokenizer_name"] = teacher_tokenizer_name + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + monkeypatch.setattr(module, "ServerManager", _CapturingServerManager) + monkeypatch.setattr( + _ConcreteTeacherEnv, + "_validate_teacher_tokenizer_compatibility", + _fake_validate, + ) + + explicit_cfg = APIServerConfig( + model_name="explicit-model", + tokenizer_name="explicit-tokenizer", + base_url="http://explicit/v1", + api_key="x", + server_type="vllm", + ) + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + + env = _ConcreteTeacherEnv( + config=config, + server_configs=[], + teacher_server_configs=[explicit_cfg], + ) + + assert isinstance(env.teacher_server, _CapturingServerManager) + assert env.teacher_server.configs == [explicit_cfg] + assert called["teacher_tokenizer_name"] == "explicit-tokenizer" + + +def test_init_wraps_bare_teacher_api_server_config(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + called = {} + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + def _fake_validate(self, teacher_tokenizer_name): + 
called["teacher_tokenizer_name"] = teacher_tokenizer_name + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + monkeypatch.setattr(module, "ServerManager", _CapturingServerManager) + monkeypatch.setattr( + _ConcreteTeacherEnv, + "_validate_teacher_tokenizer_compatibility", + _fake_validate, + ) + + explicit_cfg = APIServerConfig( + model_name="explicit-model", + tokenizer_name="explicit-tokenizer", + base_url="http://explicit/v1", + api_key="x", + server_type="vllm", + ) + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + + env = _ConcreteTeacherEnv( + config=config, + server_configs=[], + teacher_server_configs=explicit_cfg, + ) + + assert isinstance(env.teacher_server, _CapturingServerManager) + assert env.teacher_server.configs == [explicit_cfg] + assert called["teacher_tokenizer_name"] == "explicit-tokenizer" + + +def test_resolve_teacher_server_configs_returns_none_when_unset(): + assert ( + _ConcreteTeacherEnv._resolve_teacher_server_configs( + default_teacher_server_configs=None, + yaml_config={}, + cli_passed_flags={}, + ) + is None + ) + + +def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + captured = {} + + def _fake_resolve(**kwargs): + captured.update(kwargs) + return ["resolved"] + + monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve) + + default_cfg = APIServerConfig( + model_name="teacher-model", + base_url="http://teacher/v1", + api_key="x", + server_type="vllm", + ) + + out = _ConcreteTeacherEnv._resolve_teacher_server_configs( + default_teacher_server_configs=default_cfg, + yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}}, + cli_passed_flags={"teacher.base_url": "http://override/v1"}, + ) + + assert out == ["resolved"] + assert captured["openai_config_dict"]["base_url"] == "http://override/v1" + assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer" + 
assert captured["yaml_config"] == { + "openai": {"tokenizer_name": "teacher-tokenizer"} + } + assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"} diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 87823526e..de13f8c99 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -232,7 +232,6 @@ async def collect_trajectories( ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], n=self.config.group_size, @@ -352,7 +351,7 @@ async def score( # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) if all([scores["scores"][0] == score for score in scores["scores"]]): - return None # If all the same, we return None + return None return scores else: # If the gold solution is not parseable, we return None diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py new file mode 100644 index 000000000..59106f101 --- /dev/null +++ b/environments/gsm8k_server_teacher_distill.py @@ -0,0 +1,58 @@ +from typing import Tuple + +from atroposlib.envs.base import APIServerConfig, ServerBaseline +from atroposlib.envs.teacher_distillation_env import ( + TeacherDistillationConfig, + TeacherDistillationEnv, +) +from environments.gsm8k_server import GSM8kEnv + + +class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): + """ + GSM8K environment variant that enables TeacherDistillationEnv config fields. + + This preserves the original `gsm8k_server.py` while providing a separate entrypoint + for teacher-distillation data collection. 
+ """ + + name = "gsm8k_teacher_distill" + env_config_cls = TeacherDistillationConfig + + @classmethod + def config_init(cls) -> Tuple[TeacherDistillationConfig, ServerBaseline]: + env_config = TeacherDistillationConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, + wandb_name="gsm8k_teacher_distill", + teacher_enabled=True, + teacher_top_k=4, + ) + server_config = APIServerConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ) + return env_config, server_config + + @classmethod + def teacher_config_init(cls) -> APIServerConfig: + return APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="mock-teacher", + api_key="", + server_type="vllm", + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + timeout=1200, + ) + + +if __name__ == "__main__": + GSM8kTeacherDistillEnv.cli() diff --git a/example_trainer/README.md b/example_trainer/README.md index a68206142..4563e12cc 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -304,6 +304,38 @@ environment uses the `/generate` path and includes token-level 4. Trainer extracts and aligns logprobs with training labels 5. GRPO loss uses these rollout logprobs in importance-ratio terms +### 1b. Teacher distillation requires the same tokenizer + +When distillation data is attached to Atropos batches, the trainer treats +`distill_token_ids` as indices into the student's logit tensor. That only works +if the teacher and student share the same tokenizer vocabulary. 
+ +What to configure on the environment side: + +```bash +--env.teacher_enabled true \ +--teacher.base_url "http://localhost:9003/v1" \ +--teacher.model_name "$TEACHER_MODEL" \ +--teacher.server_type vllm \ +--env.teacher_top_k 8 +``` + +If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier, +also set `--teacher.tokenizer_name ...` so the env can validate +tokenizer compatibility. + +The teacher-aware CLI path is currently wired for `serve`. If +`teacher_enabled=True`, the generic `process` and `evaluate` commands are not +teacher-aware and will fail loudly unless the environment is instantiated +manually with `teacher_server_configs=...`. + +Why cross-tokenizer conversion is not acceptable here: + +- Teacher token ID `1234` and student token ID `1234` can correspond to different text. +- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`. +- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens. +- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets. + ### 2. 
Clipping ```bash diff --git a/example_trainer/api.py b/example_trainer/api.py index 21c4288e8..1bc8a1bdb 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -99,7 +99,11 @@ def get_batch(url: str = "http://localhost:8000"): Raises: RuntimeError: If trainer is not registered or other API error """ - data = requests.get(f"{url}/batch", timeout=10).json() + response = requests.get( + f"{url}/batch", + timeout=10, + ) + data = response.json() # Check if there was an error (trainer not registered) if data.get("status") == "error": diff --git a/example_trainer/run_gsm8k_lora_matrix.sh b/example_trainer/run_gsm8k_lora_matrix.sh index 121106e07..48bad3ce9 100755 --- a/example_trainer/run_gsm8k_lora_matrix.sh +++ b/example_trainer/run_gsm8k_lora_matrix.sh @@ -248,8 +248,7 @@ run_shared_vllm() { --port "$vllm_port" \ --gpu-memory-utilization "$SHARED_GPU_MEMORY_UTILIZATION" \ --max-model-len "$MAX_MODEL_LEN" \ - --dtype "$DTYPE" \ - --enforce-eager + --dtype "$DTYPE" if [[ "$DRY_RUN" == "1" ]]; then log "[DRY RUN] wait for http://localhost:${vllm_port}/health" else diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh new file mode 100755 index 000000000..fead9ba66 --- /dev/null +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -0,0 +1,313 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Single-terminal teacher-distillation runner. 
+# Starts everything in the background from ONE shell that has GPU access: +# 1) Atropos API +# 2) Student vLLM server +# 3) Teacher vLLM server +# 4) GSM8K teacher-distill environment +# 5) Example trainer (background) +# +# Usage: +# chmod +x example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +# ./example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +# +# Optional overrides: +# STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8" +# TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507" +# STUDENT_GPUS="0" +# TEACHER_GPUS="4,5,6,7" +# TRAINER_GPUS="0" +# STUDENT_TP=1 +# TEACHER_TP=4 +# API_PORT=8002 +# STUDENT_PORT=9001 +# TEACHER_PORT=9003 +# TRAINING_STEPS=100 +# DISTILL_COEF=0.2 +# DISTILL_TEMPERATURE=1.0 +# TEACHER_TOP_K=8 +# DRY_RUN=1 + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +LAUNCH_DIR="$PWD" +cd "$ROOT_DIR" + +PYTHON_BIN="${PYTHON_BIN:-python3}" +STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B}" +TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" + +STUDENT_GPUS="${STUDENT_GPUS:-0}" +TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}" +TRAINER_GPUS="${TRAINER_GPUS:-$STUDENT_GPUS}" + +STUDENT_TP="${STUDENT_TP:-1}" +TEACHER_TP="${TEACHER_TP:-4}" + +API_PORT="${API_PORT:-8002}" +STUDENT_PORT="${STUDENT_PORT:-9001}" +TEACHER_PORT="${TEACHER_PORT:-9003}" + +TRAINING_STEPS="${TRAINING_STEPS:-20}" +BATCH_SIZE="${BATCH_SIZE:-1}" +GRAD_ACCUM="${GRAD_ACCUM:-4}" +LR="${LR:-1e-5}" +WARMUP_STEPS="${WARMUP_STEPS:-0}" +CLIP_EPS="${CLIP_EPS:-0.2}" +MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" +TEACHER_MAX_MODEL_LEN="${TEACHER_MAX_MODEL_LEN:-32768}" +# Trainer seq_len must be larger than ENV_MAX_TOKEN_LENGTH to accommodate +# chat template overhead (~400-800 tokens for Qwen3 thinking format). 
+TRAINER_SEQ_LEN="${TRAINER_SEQ_LEN:-20480}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}" +DISTILL_COEF="${DISTILL_COEF:-0.2}" +DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" +TEACHER_TOP_K="${TEACHER_TOP_K:-8}" + +WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-teacher-distill}" +WANDB_GROUP="${WANDB_GROUP:-}" + +STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.60}" +TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.85}" +DTYPE="${DTYPE:-bfloat16}" +SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}" +LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}" +BRIDGE_DIR="${BRIDGE_DIR:-${LOG_DIR}/bridge}" +DRY_RUN="${DRY_RUN:-0}" + +ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}" +ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-8}" +ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" +ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" +ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}" +ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}" + +RUN_PIDS=() +RUN_PORTS=() + +log() { + echo "[$(date '+%H:%M:%S')] $*" +} + +kill_port() { + local port="$1" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] skip port cleanup for :${port}" + return 0 + fi + if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then + lsof -ti ":${port}" | xargs -r kill -9 || true + fi +} + +wait_for_http() { + local url="$1" + local timeout="${2:-240}" + local name="${3:-endpoint}" + local start + start="$(date +%s)" + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + log "Ready: ${name} (${url})" + return 0 + fi + if (( "$(date +%s)" - start > timeout )); then + log "Timeout waiting for ${name}: ${url}" + return 1 + fi + sleep 2 + done +} + +start_process() { + local name="$1" + local logfile="$2" + shift 2 + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] start ${name} (log: ${logfile})" + printf ' ' + printf '%q ' "$@" + printf '\n' + return 0 + fi + log "Starting ${name} (log: ${logfile})" + "$@" >"$logfile" 2>&1 & + local pid=$! 
+ RUN_PIDS+=("$pid") + log "${name} PID=${pid}" +} + +cleanup_all() { + log "Cleaning up processes..." + for pid in "${RUN_PIDS[@]:-}"; do + kill "$pid" >/dev/null 2>&1 || true + done + sleep 1 + for pid in "${RUN_PIDS[@]:-}"; do + kill -9 "$pid" >/dev/null 2>&1 || true + done + for port in "${RUN_PORTS[@]:-}"; do + kill_port "$port" + done +} + +trap cleanup_all EXIT INT TERM + +mkdir -p "$LOG_DIR" "$SAVE_DIR" "$BRIDGE_DIR" +RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT") +kill_port "$API_PORT" +kill_port "$STUDENT_PORT" +kill_port "$TEACHER_PORT" + +log "Config:" +log " student=${STUDENT_MODEL}" +log " teacher=${TEACHER_MODEL}" +log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPUS}" +log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}" +log " logs=${LOG_DIR}" +log " saves=${SAVE_DIR}" +log " bridge=${BRIDGE_DIR}" +log " env max_token_length=${ENV_MAX_TOKEN_LENGTH}, env workers=${ENV_MAX_WORKERS_PER_NODE}, env worker_timeout=${ENV_WORKER_TIMEOUT}" +log " wandb project=${WANDB_PROJECT}${WANDB_GROUP:+, group=${WANDB_GROUP}}" + +# Shared-vLLM attach path currently expects the student server to expose +# unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set. +if [[ "$STUDENT_TP" != "1" ]]; then + log "ERROR: shared_vllm teacher-distill runner currently requires STUDENT_TP=1." + log " The current attach path does not support TP-sharded student bridge weights." + exit 2 +fi + +if [[ "$TRAINER_GPUS" != "$STUDENT_GPUS" ]]; then + log "ERROR: TRAINER_GPUS must match STUDENT_GPUS for shared_vllm mode." 
+ log " Got student=${STUDENT_GPUS}, trainer=${TRAINER_GPUS}" + exit 2 +fi + +# 1) Atropos API +start_process "run_api" "${LOG_DIR}/run_api.log" \ + run-api --port "$API_PORT" +if [[ "$DRY_RUN" == "0" ]]; then + wait_for_http "http://localhost:${API_PORT}/info" 180 "run-api" +fi + +# 2) Student vLLM server +start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \ + env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR="$BRIDGE_DIR" \ + "$PYTHON_BIN" -m example_trainer.vllm_api_server \ + --model "$STUDENT_MODEL" \ + --port "$STUDENT_PORT" \ + --tensor-parallel-size "$STUDENT_TP" \ + --gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \ + --max-model-len "$MAX_MODEL_LEN" \ + --dtype "$DTYPE" +if [[ "$DRY_RUN" == "0" ]]; then + wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM" +fi + +# 3) Teacher vLLM server +start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \ + env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \ + "$PYTHON_BIN" -m example_trainer.vllm_api_server \ + --model "$TEACHER_MODEL" \ + --port "$TEACHER_PORT" \ + --tensor-parallel-size "$TEACHER_TP" \ + --gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \ + --max-model-len "$TEACHER_MAX_MODEL_LEN" \ + --dtype "$DTYPE" +if [[ "$DRY_RUN" == "0" ]]; then + wait_for_http "http://localhost:${TEACHER_PORT}/health" 1800 "teacher vLLM" +fi + +# 4) Teacher-distill GSM8K env +start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ + "$PYTHON_BIN" environments/gsm8k_server_teacher_distill.py serve \ + --env.group_size "$ENV_GROUP_SIZE" \ + --env.batch_size "$ENV_BATCH_SIZE" \ + --env.total_steps "$ENV_TOTAL_STEPS" \ + --env.steps_per_eval "$ENV_STEPS_PER_EVAL" \ + --env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \ + --env.max_token_length "$ENV_MAX_TOKEN_LENGTH" \ + --env.worker_timeout "$ENV_WORKER_TIMEOUT" \ + --env.rollout_server_url "http://localhost:${API_PORT}" \ + --env.use_wandb true \ + --env.wandb_name "gsm8k-teacher-distill" \ + 
--env.teacher_enabled true \ + --teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \ + --teacher.model_name "$TEACHER_MODEL" \ + --teacher.server_type vllm \ + --env.teacher_top_k "$TEACHER_TOP_K" \ + --env.ensure_scores_are_not_same false \ + --openai.api_key "dummy" \ + --openai.base_url "http://localhost:${STUDENT_PORT}/v1" \ + --openai.model_name "$STUDENT_MODEL" \ + --openai.tokenizer_name "$STUDENT_MODEL" \ + --openai.server_type vllm + +log "All services launched." +log "Run logs:" +log " ${LOG_DIR}/run_api.log" +log " ${LOG_DIR}/student_vllm.log" +log " ${LOG_DIR}/teacher_vllm.log" +log " ${LOG_DIR}/env.log" + +# 5) Trainer (background) +TRAINER_CMD=( + env + CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" + PYTHONUNBUFFERED=1 + "$PYTHON_BIN" + -u + -m + example_trainer.grpo + --model-name "$STUDENT_MODEL" + --weight-bridge-mode shared_vllm + --device cuda:0 + --save-path "$SAVE_DIR" + --atropos-url "http://localhost:${API_PORT}" + --vllm-port "$STUDENT_PORT" + --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json" + --training-steps "$TRAINING_STEPS" + --batch-size "$BATCH_SIZE" + --gradient-accumulation-steps "$GRAD_ACCUM" + --warmup-steps "$WARMUP_STEPS" + --lr "$LR" + --clip-eps "$CLIP_EPS" + --seq-len "$TRAINER_SEQ_LEN" + --distill-enabled + --distill-coef "$DISTILL_COEF" + --distill-temperature "$DISTILL_TEMPERATURE" + --use-wandb + --wandb-project "$WANDB_PROJECT" +) +if [[ -n "$WANDB_GROUP" ]]; then + TRAINER_CMD+=(--wandb-group "$WANDB_GROUP") +fi + +if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] trainer command:" + printf ' ' + printf '%q ' "${TRAINER_CMD[@]}" + printf '\n' + exit 0 +fi + +start_process "trainer" "${LOG_DIR}/trainer.log" "${TRAINER_CMD[@]}" + +log "All processes running in background." 
+log "" +log "Monitor with:" +log " tail -f ${LOG_DIR}/trainer.log" +log " tail -f ${LOG_DIR}/env.log" +log " tail -f ${LOG_DIR}/student_vllm.log" +log " tail -f ${LOG_DIR}/teacher_vllm.log" +log "" +log "Test endpoints:" +log " curl -s http://localhost:${STUDENT_PORT}/health" +log " curl -s http://localhost:${TEACHER_PORT}/health" +log " curl -s http://localhost:${STUDENT_PORT}/bridge/is_paused | jq ." +log "" +log "To stop all processes:" +log " kill ${RUN_PIDS[*]:-} 2>/dev/null; sleep 1; kill -9 ${RUN_PIDS[*]:-} 2>/dev/null" +trap - EXIT INT TERM diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 2846f14fb..24d403261 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -325,6 +325,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: final_output = request_output except asyncio.CancelledError: + logger.warning("POST /generate cancelled request_id=%s", request_id) return Response(status_code=499) assert final_output is not None @@ -348,6 +349,26 @@ async def stream_results() -> AsyncGenerator[bytes, None]: ret["prompt_token_ids"] = final_output.prompt_token_ids ret["token_ids"] = [x.token_ids for x in final_output.outputs] + if ( + sampling_params.prompt_logprobs is not None + and final_output.prompt_logprobs is not None + ): + ret["prompt_logprobs"] = [ + ( + {int(tok_id): lp.logprob for tok_id, lp in pos.items()} + if pos is not None + else None + ) + for pos in final_output.prompt_logprobs + ] + + logger.info( + "POST /generate completed request_id=%s outputs=%s finish_reasons=%s", + request_id, + len(text_outputs), + finish_reasons, + ) + return JSONResponse(ret)