Skip to content
Open
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
f44eb81
teacher env init
J-SUPHA Mar 6, 2026
530fed2
testing set up
J-SUPHA Mar 6, 2026
d5ca760
command change
J-SUPHA Mar 7, 2026
ad364ac
increase timeout cause vllm is super slow all of a sudden
J-SUPHA Mar 8, 2026
985311e
trial
J-SUPHA Mar 8, 2026
e563352
quicker training
J-SUPHA Mar 8, 2026
81f90a6
forgot something easy
J-SUPHA Mar 8, 2026
4f33ab8
apparently not so easy
J-SUPHA Mar 9, 2026
bb2736d
next
J-SUPHA Mar 10, 2026
64794e7
sneaky bug
J-SUPHA Mar 10, 2026
09ad401
sneaky bug logging
J-SUPHA Mar 10, 2026
d1fd89f
non blocking test
J-SUPHA Mar 10, 2026
057c9fe
shorten worker timeout
J-SUPHA Mar 10, 2026
e84686b
remove enforce eager
J-SUPHA Mar 10, 2026
e79af5f
testing config
J-SUPHA Mar 11, 2026
abba562
testing config
J-SUPHA Mar 11, 2026
82be871
testing config
J-SUPHA Mar 11, 2026
98a5d3b
testing config
J-SUPHA Mar 11, 2026
78c0a6d
tokenizer bug
J-SUPHA Mar 11, 2026
f1cfc13
tokenizer bug
J-SUPHA Mar 11, 2026
c275687
tokenizer bug
J-SUPHA Mar 11, 2026
3a440f8
tokenizer bug
J-SUPHA Mar 11, 2026
b457a67
tokenizer bug
J-SUPHA Mar 11, 2026
2f371e0
tokenizer bug
J-SUPHA Mar 11, 2026
8a348be
tokenizer bug
J-SUPHA Mar 11, 2026
34a3936
tokenizer bug
J-SUPHA Mar 12, 2026
fd5b426
tokenizer bug
J-SUPHA Mar 12, 2026
c37516b
tokenizer bug
J-SUPHA Mar 12, 2026
a54dfe7
tokenizer bug
J-SUPHA Mar 12, 2026
62ef2fc
training kernel
J-SUPHA Mar 12, 2026
c26432b
training kernel
J-SUPHA Mar 12, 2026
7ec622a
training ideas
J-SUPHA Mar 12, 2026
a43b0b7
training kernel
J-SUPHA Mar 12, 2026
690e670
investigating weird training issue
J-SUPHA Mar 12, 2026
3df0e45
investigating weird training issue
J-SUPHA Mar 13, 2026
d8857eb
investigating weird training issue
J-SUPHA Mar 13, 2026
d1b0dee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
600c54f
clean log
J-SUPHA Mar 13, 2026
862cd36
clean logging
J-SUPHA Mar 13, 2026
148a4fd
remove training code
J-SUPHA Mar 13, 2026
a1b545c
remove cross tokenization and fix location of configs
J-SUPHA Mar 13, 2026
994e9c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
322e7e6
remove comments
J-SUPHA Mar 13, 2026
a8cdb53
address problems
J-SUPHA Mar 13, 2026
82964b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
697c594
changes
J-SUPHA Mar 13, 2026
6c56479
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
1b8ff07
adding tests
J-SUPHA Mar 13, 2026
12ba3cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
a171358
structural changes
J-SUPHA Mar 13, 2026
3a85ede
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
9bd299b
better logging for devex
J-SUPHA Mar 14, 2026
f053c77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2026
805a0c0
revert to similar structure
J-SUPHA Mar 14, 2026
7aba0d3
fresh eyes check
J-SUPHA Mar 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions atroposlib/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, status
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import PlainTextResponse
Expand Down Expand Up @@ -351,7 +351,7 @@ async def info():


@app.get("/batch")
async def get_batch():
async def get_batch(request: Request):
# Check if trainer has registered first
if not hasattr(app.state, "started"):
return {
Expand All @@ -363,8 +363,27 @@ async def get_batch():
if not app.state.started:
app.state.started = True

client = request.client
client_addr = (
f"{client.host}:{client.port}" if client is not None else "unknown-client"
)
client_tag = request.headers.get("x-atropos-client", "unknown")
client_pid = request.headers.get("x-atropos-pid", "unknown")

if len(app.state.curr_batch) > 0:
return {"batch": app.state.curr_batch.pop()}
curr_batch = app.state.curr_batch.pop()
logger.warning(
"API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
"groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
client_tag,
client_pid,
client_addr,
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
len(app.state.queue),
)
return {"batch": curr_batch}
else:
new_batches = []
# Check if any envs have minimum allocations
Expand Down Expand Up @@ -394,16 +413,40 @@ async def get_batch():
)
steps_to_take = len(new_batches)
if steps_to_take == 0:
now = time.time()
last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
if now - last_empty_log > 30:
logger.warning(
"API /batch no full batch ready for client=%s pid=%s addr=%s: "
"queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
client_tag,
client_pid,
client_addr,
len(app.state.queue),
sum(len(x.get("tokens", [])) for x in app.state.queue),
len(app.state.curr_batch),
getattr(app.state, "batchsize", -1),
)
app.state._last_empty_batch_log = now
return {"batch": None}
app.state.status_dict["step"] += steps_to_take
# chunk it
for batch in new_batches:
app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop()
# check length before sending
logger.info(
"Sending batch of %s sequences",
logger.warning(
"API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
"with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
steps_to_take,
client_tag,
client_pid,
client_addr,
len(curr_batch),
sum(len(x["tokens"]) for x in curr_batch),
len(app.state.curr_batch),
len(app.state.queue),
app.state.status_dict["step"],
)
return {"batch": curr_batch}

Expand Down
2 changes: 1 addition & 1 deletion atroposlib/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ async def handle_send_to_api(
"ensure your trainer handles this appropriately."
)
elif abort_on_any_max_length_exceeded and any(
[len(x) >= self.max_token_len for x in group["tokens"]]
[len(x) > self.max_token_len for x in group["tokens"]]
):
logger.warning("Token length is too long in a group, skipping...")
continue
Expand Down
17 changes: 17 additions & 0 deletions atroposlib/envs/server_handling/managed_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,31 @@ async def chat_completion(self, **kwargs) -> ChatCompletion:
if not self.track_tree and self.tokenizer is not None:
input_ids = self._compute_input_ids(prompt, extending_node)
completion_kwargs["input_ids"] = input_ids
logger.warning(
"managed_server chat_completion prepared input_ids=%s extending=%s",
len(input_ids),
extending_node is not None,
)
else:
logger.warning(
"managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s",
self.track_tree,
self.tokenizer is not None,
)

# Call the tokens and logprobs wrapper directly
logger.warning("managed_server chat_completion calling backend completion wrapper")
(
prompt_tokens,
output_tokens_list,
output_logprobs_list,
finish_reasons,
) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
logger.warning(
"managed_server chat_completion backend returned prompt_tokens=%s outputs=%s",
len(prompt_tokens),
len(output_tokens_list),
)

# Track each completion and build choices
n = len(output_tokens_list)
Expand Down
17 changes: 15 additions & 2 deletions atroposlib/envs/server_handling/openai_server.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may need to be reverted?

Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def resolve_openai_configs(
f"Error creating final OpenAI configuration from merged settings: {e}\n"
f"Merged Dict: {openai_config_dict}"
) from e
server_configs = final_openai_config
server_configs = [final_openai_config]
elif isinstance(default_server_configs, ServerBaseline):
# Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible
logger.info("Using ServerBaseline configuration.")
Expand All @@ -231,7 +231,7 @@ def resolve_openai_configs(
) from e

if isinstance(default_server_configs, APIServerConfig):
server_configs = final_openai_config
server_configs = [final_openai_config]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you pass a list of configs here, it uses the configs directly. But if you pass a single non-list config object, it goes into "template mode" and auto-generates server URLs/ports.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed — the issue was the wrong config shape here. I fixed it so this path now returns [final_openai_config] instead of a bare APIServerConfig.

elif isinstance(default_server_configs, list):
server_configs = [final_openai_config]
else:
Expand All @@ -241,4 +241,17 @@ def resolve_openai_configs(
)
server_configs = [final_openai_config]

if isinstance(server_configs, list):
logger.warning(
"resolve_openai_configs: returning list of %s config(s), URLs: %s",
len(server_configs),
[c.base_url for c in server_configs],
)
else:
logger.warning(
"resolve_openai_configs: returning single %s (base_url=%s) — "
"ServerManager will use template mode!",
type(server_configs).__name__,
getattr(server_configs, "base_url", "N/A"),
)
return server_configs
20 changes: 20 additions & 0 deletions atroposlib/envs/server_handling/server_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import logging
import os
import warnings
from contextlib import asynccontextmanager
Expand All @@ -25,6 +26,8 @@
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
from atroposlib.envs.server_handling.vllm_server import VLLMServer

logger = logging.getLogger(__name__)


class ServerManagerConfig(BaseModel):
slurm: bool = Field(
Expand Down Expand Up @@ -103,6 +106,13 @@ def __init__(
self.servers = [ServerHarness()]
return
if not isinstance(configs, list):
logger.warning(
"ServerManager: configs is NOT a list (type=%s). "
"Using auto-generated URLs (template mode). "
"Passed base_url=%s will be IGNORED.",
type(configs).__name__,
getattr(configs, "base_url", "N/A"),
)
urls = []
if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
nodelist = (
Expand Down Expand Up @@ -145,11 +155,21 @@ def __init__(
server_class(config, reasoning_config=reasoning_config)
for config in openai_configs
]
logger.warning(
"ServerManager: auto-generated %s server(s) at URLs: %s",
len(self.servers),
[c.base_url for c in openai_configs],
)
elif not slurm:
self.servers = [
server_class(config, reasoning_config=reasoning_config)
for config in configs
]
logger.warning(
"ServerManager: using %s explicit config(s) at URLs: %s",
len(self.servers),
[c.base_url for c in configs],
)
else:
nodelist = (
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
Expand Down
27 changes: 27 additions & 0 deletions atroposlib/envs/server_handling/vllm_server.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# see example_trainer/vllm_api_server.py for an example

import asyncio
import logging
import warnings
from typing import Any, Dict, List, Tuple

Expand All @@ -19,6 +20,8 @@
ReasoningConfig,
)

logger = logging.getLogger(__name__)


class VLLMServer(APIServer):
"""
Expand Down Expand Up @@ -190,6 +193,14 @@ async def _tokens_and_logprobs_completion_wrapper(
# Prepare request for VLLM native API
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
request_data.update(kwargs)
logger.warning(
"vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
self.config.base_url,
len(prompt_tokens),
request_data.get("n"),
request_data.get("max_tokens"),
request_data.get("temperature"),
)

# Make async request to VLLM /generate endpoint
async with aiohttp.ClientSession() as session:
Expand All @@ -205,6 +216,11 @@ async def _tokens_and_logprobs_completion_wrapper(
) as response:
response.raise_for_status()
results = await response.json()
logger.warning(
"vllm_server completion POST done outputs=%s finish_reasons=%s",
len(results.get("logprobs", [])),
len(results.get("finish_reasons", [])),
)
output_tokens_list = []
output_logprobs_list = []
finish_reasons_list = []
Expand Down Expand Up @@ -314,6 +330,13 @@ async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
request_data["temperature"] = 0.0
request_data["top_p"] = 1.0
request_data.setdefault("max_tokens", 1)
logger.warning(
"vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
self.config.base_url,
len(prompt_tokens),
top_k,
request_data.get("max_tokens"),
)

async with aiohttp.ClientSession() as session:
async with session.post(
Expand All @@ -328,6 +351,10 @@ async def _get_logprobs_wrapper(self, **kwargs) -> Dict[str, Any]:
) as response:
response.raise_for_status()
results = await response.json()
logger.warning(
"vllm_server get_logprobs POST done prompt_logprobs_present=%s",
results.get("prompt_logprobs") is not None,
)

raw_prompt_logprobs = results.get("prompt_logprobs")
if raw_prompt_logprobs is None:
Expand Down
Loading
Loading