Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion atroposlib/api/server.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably not change anything on the server.py, and leave the zmq as a separate system

Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import asyncio
import os
import time
import uuid
from contextlib import suppress
from typing import Any, Dict, List, Optional

import wandb
import zmq
import zmq.asyncio
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
Expand All @@ -18,6 +24,9 @@
0.01 # Minimum weight to prevent environments from being completely starved
)

MESSAGE_BUS_ENABLED = os.getenv("ATROPOS_ENABLE_MESSAGE_BUS", "1") != "0"
MESSAGE_BUS_ENDPOINT = os.getenv("ATROPOS_MESSAGE_BUS_ENDPOINT", "tcp://0.0.0.0:5759")

# Message import removed - using Dict[str, Any] for more flexible validation

app = FastAPI(title="AtroposLib API")
Expand Down Expand Up @@ -114,6 +123,108 @@ class Info(BaseModel):
batch_size: int = -1


async def _log_metrics(message: Dict[str, Any]) -> None:
if not getattr(app.state, "wandb_enabled", False):
return
if getattr(app.state, "wandb_run", None) is None:
return

wandb_prepend = message.get("wandb_prepend")
metrics = message.get("metrics") or {}
server_metrics = message.get("server_metrics") or {}
rollouts = message.get("rollouts") or []
step = message.get("step")

metrics_to_log: Dict[str, Any] = {}
if wandb_prepend:
metrics_to_log.update({f"{wandb_prepend}_{k}": v for k, v in metrics.items()})
else:
metrics_to_log.update(metrics)

metrics_to_log.update(server_metrics)

if rollouts:
table = wandb.Table(columns=["text", "score"])
for entry in rollouts:
if isinstance(entry, list):
for text, score in entry:
table.add_data(text, score)
else:
text, score = entry
table.add_data(text, score)
table_key = "train/rollouts"
if wandb_prepend:
table_key = f"{wandb_prepend}_{table_key}"
metrics_to_log[table_key] = table

async with app.state.wandb_lock: # type: ignore[attr-defined]
await asyncio.to_thread(wandb.log, metrics_to_log, step=step)


async def _message_bus_worker() -> None:
socket = getattr(app.state, "message_bus_socket", None)
if socket is None:
return

while True:
try:
message = await socket.recv_json()
except asyncio.CancelledError:
break
except Exception:
continue

token = message.get("token")
if token is None:
continue
env_record = getattr(app.state, "message_bus_tokens", {}).get(token)
if env_record is None:
continue

msg_type = message.get("type")
if msg_type == "metrics":
await _log_metrics(message)


@app.on_event("startup")
async def startup_event() -> None:
if MESSAGE_BUS_ENABLED:
context = zmq.asyncio.Context.instance()
socket = context.socket(zmq.PULL)
socket.setsockopt(zmq.LINGER, 0)
socket.setsockopt(zmq.RCVHWM, 0)
socket.bind(MESSAGE_BUS_ENDPOINT)
app.state.message_bus_context = context
app.state.message_bus_socket = socket
app.state.message_bus_tokens = {}
app.state.message_bus_endpoint = MESSAGE_BUS_ENDPOINT
app.state.message_bus_task = asyncio.create_task(_message_bus_worker())

app.state.wandb_run = None
app.state.wandb_enabled = False
app.state.wandb_lock = asyncio.Lock()


@app.on_event("shutdown")
async def shutdown_event() -> None:
message_bus_task = getattr(app.state, "message_bus_task", None)
if message_bus_task:
message_bus_task.cancel()
with suppress(Exception):
await message_bus_task
socket = getattr(app.state, "message_bus_socket", None)
if socket is not None:
socket.close(linger=0)
app.state.message_bus_socket = None
if getattr(app.state, "wandb_run", None) is not None:
wandb.finish()
app.state.wandb_run = None
if getattr(app.state, "message_bus_context", None) is not None:
context = app.state.message_bus_context
context.term()
app.state.message_bus_context = None


@app.post("/register")
async def register(registration: Registration):
# Initialize app state if not already done
Expand All @@ -131,6 +242,26 @@ async def register(registration: Registration):
app.state.started = False
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment
if MESSAGE_BUS_ENABLED:
app.state.wandb_enabled = True
if getattr(app.state, "wandb_run", None) is not None:
wandb.finish()
wandb_mode = os.getenv("WANDB_MODE")
if wandb_mode is None and not os.getenv("WANDB_API_KEY"):
wandb_mode = "disabled"
init_kwargs: Dict[str, Any] = {
"project": registration.wandb_project,
"group": registration.wandb_group,
"config": {
"batch_size": registration.batch_size,
"max_token_len": registration.max_token_len,
"num_steps": registration.num_steps,
},
"settings": wandb.Settings(start_method="thread"),
}
if wandb_mode is not None:
init_kwargs["mode"] = wandb_mode
app.state.wandb_run = wandb.init(**init_kwargs)

# Initialize requesters list if not already done
if not hasattr(app.state, "requesters"):
Expand Down Expand Up @@ -172,7 +303,7 @@ async def register_env_url(register_env: RegisterEnv):
"group_size": register_env.group_size,
}
)
return {
response = {
"status": "success",
"env_id": registered_id,
"wandb_name": real_name,
Expand All @@ -181,12 +312,35 @@ async def register_env_url(register_env: RegisterEnv):
"checkpoint_interval": app.state.save_checkpoint_interval,
"num_steps": app.state.num_steps,
}
if (
MESSAGE_BUS_ENABLED
and getattr(app.state, "message_bus_endpoint", None) is not None
):
token = uuid.uuid4().hex
app.state.envs[registered_id]["message_token"] = token
if getattr(app.state, "message_bus_tokens", None) is None:
app.state.message_bus_tokens = {}
app.state.message_bus_tokens[token] = {
"registered_id": registered_id,
"desired_name": register_env.desired_name,
"real_name": real_name,
}
response["message_bus"] = {
"endpoint": app.state.message_bus_endpoint,
"token": token,
"env_name": register_env.desired_name,
"wandb_prepend": real_name,
}
return response


@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
try:
app.state.envs[disconnect_env.env_id]["connected"] = False
token = app.state.envs[disconnect_env.env_id].get("message_token")
if token and getattr(app.state, "message_bus_tokens", None) is not None:
app.state.message_bus_tokens.pop(token, None)
return {"status": "success"}
except (AttributeError, IndexError) as e:
return {"status": "failure", "error": str(e)}
Expand Down Expand Up @@ -533,6 +687,15 @@ async def reset_data():
app.state.requesters = []
app.state.envs = []
app.state.buffer = {}
if getattr(app.state, "message_bus_tokens", None) is not None:
app.state.message_bus_tokens.clear()
except KeyError:
pass
if getattr(app.state, "wandb_run", None) is not None:
wandb.finish()
app.state.wandb_run = None
app.state.wandb_enabled = (
MESSAGE_BUS_ENABLED
and getattr(app.state, "message_bus_socket", None) is not None
)
return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)
Loading