Changes from 43 commits (55 commits total)
- `f44eb81` teacher env init (J-SUPHA, Mar 6, 2026)
- `530fed2` testing set up (J-SUPHA, Mar 6, 2026)
- `d5ca760` command change (J-SUPHA, Mar 7, 2026)
- `ad364ac` increase timeout cause vllm is super slow all of a sudden (J-SUPHA, Mar 8, 2026)
- `985311e` trial (J-SUPHA, Mar 8, 2026)
- `e563352` quicker training (J-SUPHA, Mar 8, 2026)
- `81f90a6` forgot something easy (J-SUPHA, Mar 8, 2026)
- `4f33ab8` apparently not so easy (J-SUPHA, Mar 9, 2026)
- `bb2736d` next (J-SUPHA, Mar 10, 2026)
- `64794e7` sneaky bug (J-SUPHA, Mar 10, 2026)
- `09ad401` sneaky bug logging (J-SUPHA, Mar 10, 2026)
- `d1fd89f` non blocking test (J-SUPHA, Mar 10, 2026)
- `057c9fe` shorten worker timeout (J-SUPHA, Mar 10, 2026)
- `e84686b` remove enforce eager (J-SUPHA, Mar 10, 2026)
- `e79af5f` testing config (J-SUPHA, Mar 11, 2026)
- `abba562` testing config (J-SUPHA, Mar 11, 2026)
- `82be871` testing config (J-SUPHA, Mar 11, 2026)
- `98a5d3b` testing config (J-SUPHA, Mar 11, 2026)
- `78c0a6d` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `f1cfc13` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `c275687` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `3a440f8` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `b457a67` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `2f371e0` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `8a348be` tokenizer bug (J-SUPHA, Mar 11, 2026)
- `34a3936` tokenizer bug (J-SUPHA, Mar 12, 2026)
- `fd5b426` tokenizer bug (J-SUPHA, Mar 12, 2026)
- `c37516b` tokenizer bug (J-SUPHA, Mar 12, 2026)
- `a54dfe7` tokenizer bug (J-SUPHA, Mar 12, 2026)
- `62ef2fc` training kernel (J-SUPHA, Mar 12, 2026)
- `c26432b` training kernel (J-SUPHA, Mar 12, 2026)
- `7ec622a` training ideas (J-SUPHA, Mar 12, 2026)
- `a43b0b7` training kernel (J-SUPHA, Mar 12, 2026)
- `690e670` investigating weird training issue (J-SUPHA, Mar 12, 2026)
- `3df0e45` investigating weird training issue (J-SUPHA, Mar 13, 2026)
- `d8857eb` investigating weird training issue (J-SUPHA, Mar 13, 2026)
- `d1b0dee` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2026)
- `600c54f` clean log (J-SUPHA, Mar 13, 2026)
- `862cd36` clean logging (J-SUPHA, Mar 13, 2026)
- `148a4fd` remove training code (J-SUPHA, Mar 13, 2026)
- `a1b545c` remove cross tokenization and fix location of configs (J-SUPHA, Mar 13, 2026)
- `994e9c2` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2026)
- `322e7e6` remove comments (J-SUPHA, Mar 13, 2026)
- `a8cdb53` address problems (J-SUPHA, Mar 13, 2026)
- `82964b6` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2026)
- `697c594` changes (J-SUPHA, Mar 13, 2026)
- `6c56479` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2026)
- `1b8ff07` adding tests (J-SUPHA, Mar 13, 2026)
- `12ba3cc` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2026)
- `a171358` structural changes (J-SUPHA, Mar 13, 2026)
- `3a85ede` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2026)
- `9bd299b` better logging for devex (J-SUPHA, Mar 14, 2026)
- `f053c77` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 14, 2026)
- `805a0c0` revert to similar structure (J-SUPHA, Mar 14, 2026)
- `7aba0d3` fresh eyes check (J-SUPHA, Mar 14, 2026)
4 changes: 2 additions & 2 deletions .secrets.baseline
@@ -133,7 +133,7 @@
"filename": "README.md",
"hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5",
"is_verified": false,
"line_number": 454
"line_number": 495
}
],
"SLURM.md": [
@@ -561,5 +561,5 @@
}
]
},
"generated_at": "2026-03-02T22:46:56Z"
"generated_at": "2026-03-13T17:20:46Z"
}
41 changes: 41 additions & 0 deletions README.md
@@ -298,6 +298,47 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.

### TeacherDistillationEnv follow-up

The follow-up teacher environment uses a dedicated teacher server config and
attaches teacher prompt logprobs before the group is sent to the API.

Teacher config shape:

```python
TeacherDistillationConfig(
teacher_enabled=True,
teacher_server=APIServerConfig(
base_url="http://localhost:9003/v1",
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
api_key="",
server_type="vllm",
),
teacher_top_k=8,
)
```

CLI shape:

```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
--env.teacher_server.server_type vllm \
--env.teacher_top_k 8
```

Tokenizer requirement:

- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.

Why same-tokenizer is required:

- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.
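The gate that enforces this can be sketched as follows. This is a minimal, hypothetical helper for illustration only (`check_same_tokenizer` is not part of the PR); the actual check in `TeacherDistillationEnv` compares the full `get_vocab()` dictionaries of the two tokenizers at init:

```python
# Minimal sketch of the same-tokenizer gate (illustrative only; the real
# check lives in TeacherDistillationEnv and compares tokenizer.get_vocab()).
def check_same_tokenizer(student_vocab: dict, teacher_vocab: dict) -> None:
    # Exact vocabulary equality: every token string must map to the same ID.
    if student_vocab != teacher_vocab:
        raise ValueError("Cross-tokenizer distillation is not supported")

# Identical vocabularies pass silently.
check_same_tokenizer({"cat": 0, "dog": 1}, {"cat": 0, "dog": 1})

# Swapped IDs for the same text must be rejected: the trainer would
# otherwise index the wrong rows of the student's logits.
try:
    check_same_tokenizer({"cat": 0, "dog": 1}, {"cat": 1, "dog": 0})
except ValueError as exc:
    print(exc)  # prints "Cross-tokenizer distillation is not supported"
```

Note that strict dict equality is deliberately conservative: even two checkpoints of the same model family fail the gate if their vocabularies differ by a single added special token.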

---

## Testing and Debugging Tools
11 changes: 2 additions & 9 deletions atroposlib/envs/server_handling/openai_server.py
Collaborator: i think this may need to be reverted?
@@ -210,7 +210,7 @@ def resolve_openai_configs(
f"Error creating final OpenAI configuration from merged settings: {e}\n"
f"Merged Dict: {openai_config_dict}"
) from e
-        server_configs = final_openai_config
+        server_configs = [final_openai_config]
elif isinstance(default_server_configs, ServerBaseline):
# Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible
logger.info("Using ServerBaseline configuration.")
@@ -230,15 +230,8 @@
f"Merged Dict: {openai_config_dict}"
) from e

-    if isinstance(default_server_configs, APIServerConfig):
-        server_configs = final_openai_config
-    elif isinstance(default_server_configs, list):
+    if isinstance(default_server_configs, list):
         server_configs = [final_openai_config]
-    else:
-        logger.warning(
-            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
-            f"Proceeding with single OpenAI server configuration based on merged settings."
-        )
-        server_configs = [final_openai_config]

return server_configs
4 changes: 0 additions & 4 deletions atroposlib/envs/server_handling/vllm_server.py
Collaborator: revert
@@ -424,10 +424,6 @@ def resolve_openai_configs(
     elif isinstance(default_server_configs, list):
         server_configs = [final_openai_config]
     else:
-        logger.warning(
-            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
-            f"Proceeding with single OpenAI server configuration based on merged settings."
-        )
         server_configs = [final_openai_config]

return server_configs
216 changes: 216 additions & 0 deletions atroposlib/envs/teacher_distillation_env.py
@@ -0,0 +1,216 @@
"""
Teacher distillation environment layer.

This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.

This implementation supports same-tokenizer distillation only. The teacher and
student must share the same tokenizer vocabulary so the student's token IDs can
be forwarded directly to the teacher and the returned teacher top-k token IDs
can be looked up directly in the student's logits.
"""

from __future__ import annotations

import asyncio
import logging
from abc import ABC
from typing import Any, List, Optional, Tuple, Union

from pydantic import Field

from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
from .server_handling.server_manager import ServerManager

logger = logging.getLogger(__name__)


class TeacherDistillationConfig(BaseEnvConfig):
teacher_enabled: bool = Field(
default=False,
description="Whether to fetch teacher prompt logprobs for distillation.",
)
teacher_server: Optional[APIServerConfig] = Field(
default=None,
description="Teacher inference server configuration.",
)
Collaborator: ah, i probably commented poorly, it should be the same as how we setup the server_manager, so we may need to pass in a new thing to init

J-SUPHA (author), Mar 13, 2026: Updated this to follow this pattern. I removed teacher_server from TeacherDistillationConfig, so the env config now only carries env-level knobs like teacher_enabled and teacher_top_k. Teacher server wiring is now passed separately via teacher_server_configs at init.
teacher_top_k: int = Field(
default=1,
ge=1,
description="Top-k prompt logprobs to fetch per token position.",
)


class TeacherDistillationEnv(BaseEnv, ABC):
"""
BaseEnv subclass that enriches scored groups with teacher distillation arrays.

Distillation payload shape:
- distill_token_ids: [sequence][position][k] (student vocab IDs)
- distill_logprobs: [sequence][position][k]
"""

env_config_cls = TeacherDistillationConfig

def __init__(
self,
config: TeacherDistillationConfig,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
slurm: bool = False,
testing: bool = False,
):
super().__init__(config, server_configs, slurm=slurm, testing=testing)
self.teacher_server: Optional[ServerManager] = None

if config.teacher_enabled:
if config.teacher_server is None:
raise ValueError(
"teacher_enabled=True requires a teacher_server configuration."
)
teacher_cfg = config.teacher_server.model_copy(
update={
"tokenizer_name": (
config.teacher_server.model_name
if config.teacher_server.tokenizer_name in ("", "none")
else config.teacher_server.tokenizer_name
),
"timeout": 1200,
}
)
self.teacher_server = ServerManager(
[teacher_cfg],
slurm=False,
testing=False,
)
self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)

# ------------------------------------------------------------------
# Core fetch
# ------------------------------------------------------------------

def _validate_teacher_tokenizer_compatibility(
self, teacher_tokenizer_name: str
) -> None:
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
if student_tok_name == teacher_tokenizer_name:
return

try:
from transformers import AutoTokenizer

teacher_tokenizer = AutoTokenizer.from_pretrained(
teacher_tokenizer_name, use_fast=True
)
except Exception as exc:
raise ValueError(
"Cross-tokenizer distillation is not supported in this PR, and the "
f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
f"verify compatibility: {exc}"
) from exc

student_vocab = self.tokenizer.get_vocab()
teacher_vocab = teacher_tokenizer.get_vocab()
if student_vocab != teacher_vocab:
raise ValueError(
"Cross-tokenizer distillation is not supported in this PR. "
f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
)

async def _fetch_teacher_for_sequence(
self, token_ids: List[int], top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
assert self.teacher_server is not None
payload = await self.teacher_server.get_logprobs(
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

# ------------------------------------------------------------------
# Group enrichment
# ------------------------------------------------------------------

async def _attach_teacher_distillation(
self, group: ScoredDataGroup
) -> ScoredDataGroup:
if not self.config.teacher_enabled or self.teacher_server is None:
return group

seqs = group.get("tokens", [])
if not seqs:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group

top_k = int(
(group.get("group_overrides") or {}).get(
"teacher_top_k", self.config.teacher_top_k
)
)
top_k = max(1, top_k)
Collaborator: max should be 0, because prompt logprobs are (selected token + topk), disabled would be setting it to -1 or lower. I would also be amenable to a group override that's skip_teacher_top_k


tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
results = await asyncio.gather(*tasks, return_exceptions=True)

distill_token_ids: List[List[List[int]]] = []
distill_logprobs: List[List[List[float]]] = []
for idx, result in enumerate(results):
if isinstance(result, Exception):
logger.warning(
"Teacher logprob fetch failed for seq %s: %s. "
"Dropping distill payload for this group.",
idx,
result,
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
token_ids_k, logprobs_k = result
if len(token_ids_k) != len(logprobs_k):
logger.warning(
"Teacher prompt-topk length mismatch for seq %s (%s != %s). "
"Dropping distill payload for this group.",
idx,
len(token_ids_k),
len(logprobs_k),
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
distill_token_ids.append(token_ids_k)
distill_logprobs.append(logprobs_k)

group["distill_token_ids"] = distill_token_ids
group["distill_logprobs"] = distill_logprobs
return group

async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Any = None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
groups = scored_data if isinstance(scored_data, list) else [scored_data]
enriched_groups: List[ScoredDataGroup] = []
for group in groups:
if group is None:
continue
enriched_groups.append(await self._attach_teacher_distillation(group))

payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
if isinstance(scored_data, list):
payload = enriched_groups
else:
payload = enriched_groups[0] if enriched_groups else scored_data

return await super().handle_send_to_api(
payload,
item=item,
do_send_to_api=do_send_to_api,
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
)