Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b379dbc
wandb weave tracing integration
ropresearch Nov 5, 2025
6f169bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
e122194
precommit format fixes
ropresearch Nov 6, 2025
6de15ab
Merge branch 'rop/weave' of https://github.com/NousResearch/atropos i…
ropresearch Nov 6, 2025
5c34f48
documentation updates
ropresearch Nov 6, 2025
f5da18f
reverted weave ops, only for base env completions tracing
ropresearch Nov 10, 2025
62fa2ab
format and weave cleanup
ropresearch Nov 10, 2025
2df4ee9
Update to allow tracing flag through config
ropresearch Nov 13, 2025
0179b25
zmq message passing & env data aggregation for wandb
ropresearch Nov 18, 2025
646da3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
f58c927
decoupled zmq from server through run-api cli command subprocess
ropresearch Nov 19, 2025
03cc5e3
Merge branch 'rop/zmq-message-pass' of https://github.com/NousResearc…
ropresearch Nov 19, 2025
e2fe5e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
735c14f
Merge pull request #282 from NousResearch/rop/zmq-message-pass
ropresearch Nov 19, 2025
0050d1f
Merge branch 'main' into rop/weave
ropresearch Dec 4, 2025
d1bcadc
tag fixes and ZMQ change for better env categorization
ropresearch Dec 16, 2025
f28c344
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
d227d9a
env controlled wandb log pushes
ropresearch Dec 23, 2025
fba8922
Merge remote rop/weave, resolve sidecar conflicts
ropresearch Dec 23, 2025
7654e3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 175 additions & 13 deletions atroposlib/api/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import gzip
import logging
import os
import time
import uuid
from typing import Any, Dict, List, Optional

import zmq
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
Expand All @@ -17,6 +20,8 @@
grab_exact_from_heterogeneous_queue,
)

logger = logging.getLogger(__name__)

# Constants
MIN_ENV_WEIGHT = (
0.01 # Minimum weight to prevent environments from being completely starved
Expand Down Expand Up @@ -245,6 +250,19 @@ class Info(BaseModel):
batch_size: int = -1


def send_to_sidecar(payload: Dict[str, Any], port: int) -> None:
    """Best-effort delivery of *payload* to the ZMQ metrics sidecar.

    Opens a short-lived PUSH socket per call, pickles the payload with
    ``send_pyobj``, and tears the socket down. Send and teardown are bounded
    by timeouts so a missing or stalled sidecar cannot block the caller —
    this helper is invoked from async request handlers, where an unbounded
    block would stall the whole event loop. Failures are logged, never raised.

    Args:
        payload: Arbitrary picklable message for the sidecar (e.g. init,
            env_register, env_disconnect, reset commands).
        port: TCP port the sidecar's PULL socket listens on (localhost).
    """
    try:
        context = zmq.Context()
        socket = context.socket(zmq.PUSH)
        # Without these options, send_pyobj() blocks forever when no sidecar
        # is listening (PUSH with no peers), and context.term() blocks until
        # every queued message is delivered (default LINGER is infinite).
        socket.setsockopt(zmq.SNDTIMEO, 200)  # max ms to wait for a peer
        socket.setsockopt(zmq.LINGER, 200)  # max ms to flush on close/term
        socket.connect(f"tcp://localhost:{port}")
        socket.send_pyobj(payload)
        socket.close()
        context.term()
    except Exception as e:
        # Sidecar delivery is fire-and-forget; never let it break a request.
        logger.error(f"Failed to send payload to sidecar: {e}")


@app.post("/register")
async def register(registration: Registration):
# Initialize app state if not already done
Expand All @@ -263,6 +281,38 @@ async def register(registration: Registration):
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment

# Init ZMQ config
zmq_port_str = os.getenv("ATROPOS_ZMQ_PORT", "5555")
zmq_port = int(zmq_port_str)
app.state.zmq_port = zmq_port

if registration.wandb_project:
# Resume Logic:
# hash of the group name is saved locally so that if the server crashes
# or some other issue happens we can keep using the same run for the env instances
import hashlib

run_id = hashlib.md5(registration.wandb_group.encode()).hexdigest()

# Generate config for sidecar
wandb_config = {
"project": registration.wandb_project,
"group": registration.wandb_group,
"name": f"API_Aggregator_{registration.wandb_group}",
"id": run_id,
"resume": "allow",
"config": registration.model_dump(),
"reinit": True,
}

# Send INIT command to sidecar
send_to_sidecar({"_type": "init", "config": wandb_config}, zmq_port)

# Store run ID for /wandb_info
app.state.wandb_run_id = run_id
else:
app.state.wandb_run_id = None

# Initialize requesters list if not already done
if not hasattr(app.state, "requesters"):
app.state.requesters = []
Expand All @@ -273,23 +323,35 @@ async def register(registration: Registration):

@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
# Check if trainer has started
if not hasattr(app.state, "started") or not app.state.started:
return {
"status": "wait for trainer to start",
}
return {"status": "wait for trainer to start"}

# Initialize envs list if not already done
if not hasattr(app.state, "envs"):
app.state.envs = []
if not hasattr(app.state, "env_leaders"):
app.state.env_leaders = {}
if not hasattr(app.state, "next_leader_port"):
app.state.next_leader_port = 5600

# Get checkpoint directory safely
checkpoint_dir = getattr(app.state, "checkpoint_dir", "")
real_name = (
f"{register_env.desired_name}_"
f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}"
instance_index = len(
[x for x in app.state.envs if x["desired_name"] == register_env.desired_name]
)
real_name = f"{register_env.desired_name}_{instance_index}"
registered_id = len(app.state.envs)

is_leader = register_env.desired_name not in app.state.env_leaders
leader_receive_port = None

if is_leader:
leader_receive_port = app.state.next_leader_port
app.state.next_leader_port += 1
app.state.env_leaders[register_env.desired_name] = {
"instance": real_name,
"env_id": registered_id,
"receive_port": leader_receive_port,
}

app.state.envs.append(
{
"max_context_len": register_env.max_token_length,
Expand All @@ -301,8 +363,21 @@ async def register_env_url(register_env: RegisterEnv):
"connected": True,
"min_batch_allocation": register_env.min_batch_allocation,
"group_size": register_env.group_size,
"is_leader": is_leader,
}
)

if hasattr(app.state, "zmq_port"):
msg = {
"_type": "env_register",
"env_type": register_env.desired_name,
"instance": real_name,
"is_leader": is_leader,
}
if is_leader:
msg["leader_receive_port"] = leader_receive_port
send_to_sidecar(msg, app.state.zmq_port)

return {
"status": "success",
"env_id": registered_id,
Expand All @@ -311,13 +386,33 @@ async def register_env_url(register_env: RegisterEnv):
"starting_step": app.state.status_dict["step"],
"checkpoint_interval": app.state.save_checkpoint_interval,
"num_steps": app.state.num_steps,
"is_leader": is_leader,
"leader_receive_port": leader_receive_port,
"wandb_project": getattr(app.state, "project", None),
"wandb_group": getattr(app.state, "group", None),
}


@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    """Mark an environment instance as disconnected.

    Flags the env entry as disconnected, notifies the ZMQ sidecar (best
    effort, only if a sidecar port was configured at /register time), and —
    if the instance was its env-type's leader — frees the leader slot so a
    future instance of the same type can become leader.

    Returns:
        ``{"status": "success"}`` on success, or
        ``{"status": "failure", "error": ...}`` if the env_id is unknown or
        state was never initialized.
    """
    try:
        # NOTE: the legacy direct assignment to
        # app.state.envs[...]["connected"] was removed — it duplicated the
        # env["connected"] = False below (leftover from an older revision).
        env = app.state.envs[disconnect_env.env_id]
        env["connected"] = False

        # Best-effort sidecar notification; send_to_sidecar never raises.
        if hasattr(app.state, "zmq_port"):
            send_to_sidecar(
                {
                    "_type": "env_disconnect",
                    "env_type": env["desired_name"],
                    "instance": env["real_name"],
                    "was_leader": env.get("is_leader", False),
                },
                app.state.zmq_port,
            )

        # Release the leader slot so /register-env can elect a new leader.
        if env.get("is_leader") and hasattr(app.state, "env_leaders"):
            app.state.env_leaders.pop(env["desired_name"], None)

        return {"status": "success"}
    except (AttributeError, IndexError) as e:
        # AttributeError: app.state not initialized; IndexError: bad env_id.
        return {"status": "failure", "error": str(e)}
Expand All @@ -326,9 +421,19 @@ async def disconnect_env(disconnect_env: EnvIdentifier):
@app.get("/wandb_info")
async def wandb_info():
    """Return wandb/ZMQ coordination info for environment instances.

    Exposes the wandb group/project registered by the trainer, plus the ZMQ
    sidecar port and the deterministic wandb run id (when a sidecar was
    configured). All fields fall back to ``None`` when state is missing, so
    callers polling before /register completes get a well-formed payload.

    NOTE: the previous two-field ``return {"group": ..., "project": ...}``
    statements shadowed the extended payload below and made zmq_port /
    wandb_run_id unreachable; they have been removed.
    """
    try:
        return {
            "group": app.state.group,
            "project": app.state.project,
            "zmq_port": getattr(app.state, "zmq_port", None),
            "wandb_run_id": getattr(app.state, "wandb_run_id", None),
        }
    except AttributeError:
        # app.state.group / app.state.project not set yet.
        return {
            "group": None,
            "project": None,
            "zmq_port": None,
            "wandb_run_id": None,
        }


@app.get("/info")
Expand Down Expand Up @@ -409,7 +514,60 @@ async def get_latest_example():

@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    """Accept one scored rollout group from an environment.

    Groups whose size matches the env's registered ``group_size`` go straight
    onto the training queue. Mixed-size groups are buffered per-env until a
    subset of buffered groups sums exactly to the expected group size, at
    which point those groups are flushed to the queue in FIFO order.

    NOTE: a stale ``return _process_scored_data(scored_data)`` from an older
    revision used to sit at the top of this handler, making the entire body
    below unreachable; it has been removed.

    Returns:
        ``{"status": "received"}`` when queued directly, or
        ``{"status": "buffered", "buffer_size": ...}`` when held in the
        per-env mixed-size buffer (buffer_size counts buffered sequences).
    """
    data_dict = {
        "tokens": scored_data.tokens,
        "masks": scored_data.masks,
        "scores": scored_data.scores,
        "advantages": scored_data.advantages,
        "ref_logprobs": scored_data.ref_logprobs,
        "messages": scored_data.messages,
        "generation_params": scored_data.generation_params,
        "inference_logprobs": scored_data.inference_logprobs,
        "overrides": scored_data.overrides,
        "group_overrides": scored_data.group_overrides,
        "images": scored_data.images,
        "env_id": scored_data.env_id,
    }
    # Check if this is a mixed-size group. The 0 <= guard prevents a negative
    # env_id from silently indexing from the end of app.state.envs.
    env_id = scored_data.env_id
    if env_id is not None and 0 <= env_id < len(app.state.envs):
        expected_group_size = app.state.envs[env_id].get("group_size", 1)
        actual_group_size = len(scored_data.tokens)

        if actual_group_size != expected_group_size:
            # Mixed size group - add to buffer
            if env_id not in app.state.buffer:
                app.state.buffer[env_id] = []

            app.state.buffer[env_id].append(data_dict)

            # Try to find a subset of buffered groups whose sizes sum to
            # exactly expected_group_size.
            indices = find_groups_summing_to_target(
                app.state.buffer[env_id], expected_group_size
            )

            if indices:
                # Pop highest index first so earlier indices stay valid.
                groups_to_add = []
                for idx in sorted(indices, reverse=True):
                    groups_to_add.append(app.state.buffer[env_id].pop(idx))

                # Re-append in original (FIFO) arrival order.
                for group in reversed(groups_to_add):
                    app.state.queue.append(group)
                    app.state.latest = group

            return {
                "status": "buffered",
                "buffer_size": sum(
                    len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
                ),
            }

    # Normal path - correct size or no env info
    app.state.queue.append(data_dict)
    app.state.latest = data_dict
    return {"status": "received"}


@app.post("/scored_data_list")
Expand Down Expand Up @@ -564,6 +722,10 @@ async def get_status_env(env: EnvIdentifier):
@app.get("/reset_data")
async def reset_data():
try:
# Send RESET to sidecar
if hasattr(app.state, "zmq_port"):
send_to_sidecar({"_type": "reset"}, app.state.zmq_port)

del app.state.queue
app.state.group = None
app.state.project = None
Expand Down
Loading