Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b379dbc
wandb weave tracing integration
ropresearch Nov 5, 2025
6f169bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
e122194
precommit format fixes
ropresearch Nov 6, 2025
6de15ab
Merge branch 'rop/weave' of https://github.com/NousResearch/atropos i…
ropresearch Nov 6, 2025
5c34f48
documentation updates
ropresearch Nov 6, 2025
f5da18f
reverted weave ops, only for base env completions tracing
ropresearch Nov 10, 2025
62fa2ab
format and weave cleanup
ropresearch Nov 10, 2025
2df4ee9
Update to allow tracing flag through config
ropresearch Nov 13, 2025
0179b25
zmq message passing & env data aggregation for wandb
ropresearch Nov 18, 2025
646da3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
f58c927
decoupled zmq from server through run-api cli command subprocess
ropresearch Nov 19, 2025
03cc5e3
Merge branch 'rop/zmq-message-pass' of https://github.com/NousResearc…
ropresearch Nov 19, 2025
e2fe5e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
735c14f
Merge pull request #282 from NousResearch/rop/zmq-message-pass
ropresearch Nov 19, 2025
0050d1f
Merge branch 'main' into rop/weave
ropresearch Dec 4, 2025
d1bcadc
tag fixes and ZMQ change for better env categorization
ropresearch Dec 16, 2025
f28c344
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
d227d9a
env controlled wandb log pushes
ropresearch Dec 23, 2025
fba8922
Merge remote rop/weave, resolve sidecar conflicts
ropresearch Dec 23, 2025
7654e3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 120 additions & 3 deletions atroposlib/api/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import gzip
import logging
import os
import time
import uuid
from typing import Any, Dict, List, Optional

import zmq
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
Expand All @@ -17,6 +20,8 @@
grab_exact_from_heterogeneous_queue,
)

logger = logging.getLogger(__name__)

# Constants
MIN_ENV_WEIGHT = (
0.01 # Minimum weight to prevent environments from being completely starved
Expand Down Expand Up @@ -245,6 +250,19 @@ class Info(BaseModel):
batch_size: int = -1


def send_to_sidecar(payload: Dict[str, Any], port: int):
    """Send *payload* to the ZMQ logging sidecar over a one-shot PUSH socket.

    Best-effort fire-and-forget: any failure is logged and swallowed so that
    API request handling is never disrupted when the sidecar is down.

    Args:
        payload: Picklable dict; control messages carry a ``"_type"`` key.
        port: TCP port the sidecar's PULL socket listens on (localhost).
    """
    context = None
    socket = None
    try:
        context = zmq.Context()
        socket = context.socket(zmq.PUSH)
        # Bound LINGER: a PUSH socket with no connected peer queues the
        # message forever, and the default LINGER (-1) makes context.term()
        # block indefinitely on that undelivered message. Wait at most 1s to
        # flush, then drop, so the API never hangs when no sidecar is running.
        socket.setsockopt(zmq.LINGER, 1000)
        socket.connect(f"tcp://localhost:{port}")
        socket.send_pyobj(payload)
    except Exception as e:
        logger.error(f"Failed to send payload to sidecar: {e}")
    finally:
        # Always release ZMQ resources, even when connect/send raised.
        if socket is not None:
            socket.close()
        if context is not None:
            context.term()


@app.post("/register")
async def register(registration: Registration):
# Initialize app state if not already done
Expand All @@ -263,6 +281,38 @@ async def register(registration: Registration):
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment

# Init ZMQ config
zmq_port_str = os.getenv("ATROPOS_ZMQ_PORT", "5555")
zmq_port = int(zmq_port_str)
app.state.zmq_port = zmq_port

if registration.wandb_project:
# Resume Logic:
    # hash of the group name saved locally so that if the server crashes
    # or some other issue happens we can keep using the same run for the env instances
import hashlib

run_id = hashlib.md5(registration.wandb_group.encode()).hexdigest()

# Generate config for sidecar
wandb_config = {
"project": registration.wandb_project,
"group": registration.wandb_group,
"name": f"API_Aggregator_{registration.wandb_group}",
"id": run_id,
"resume": "allow",
"config": registration.model_dump(),
"reinit": True,
}

# Send INIT command to sidecar
send_to_sidecar({"_type": "init", "config": wandb_config}, zmq_port)

# Store run ID for /wandb_info
app.state.wandb_run_id = run_id
else:
app.state.wandb_run_id = None

# Initialize requesters list if not already done
if not hasattr(app.state, "requesters"):
app.state.requesters = []
Expand Down Expand Up @@ -326,9 +376,19 @@ async def disconnect_env(disconnect_env: EnvIdentifier):
@app.get("/wandb_info")
async def wandb_info():
    """Report the wandb coordination info (group/project/run id) plus the
    ZMQ sidecar port so env instances and watchers can attach themselves."""
    try:
        group = app.state.group
        project = app.state.project
    except AttributeError:
        # State not initialized yet (no /register call so far): report nothing.
        return {
            "group": None,
            "project": None,
            "zmq_port": None,
            "wandb_run_id": None,
        }
    return {
        "group": group,
        "project": project,
        "zmq_port": getattr(app.state, "zmq_port", None),
        "wandb_run_id": getattr(app.state, "wandb_run_id", None),
    }


@app.get("/info")
Expand Down Expand Up @@ -409,7 +469,60 @@ async def get_latest_example():

@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    """Receive one scored rollout group from an environment.

    Groups whose size matches the environment's registered ``group_size`` go
    straight onto the training queue. Mismatched ("mixed-size") groups are
    buffered per-env; whenever some subset of buffered groups sums exactly to
    the expected size, that subset is flushed to the queue in FIFO order.

    Returns:
        ``{"status": "received"}`` when the group was queued directly, or
        ``{"status": "buffered", "buffer_size": <total buffered sequences>}``
        when it was held back in the mixed-size buffer.
    """
    # Snapshot the Pydantic model into a plain dict — this dict (not the
    # model) is the unit stored in both the buffer and the training queue.
    data_dict = {
        "tokens": scored_data.tokens,
        "masks": scored_data.masks,
        "scores": scored_data.scores,
        "advantages": scored_data.advantages,
        "ref_logprobs": scored_data.ref_logprobs,
        "messages": scored_data.messages,
        "generation_params": scored_data.generation_params,
        "inference_logprobs": scored_data.inference_logprobs,
        "overrides": scored_data.overrides,
        "group_overrides": scored_data.group_overrides,
        "images": scored_data.images,
        "env_id": scored_data.env_id,
    }
    # Check if this is a mixed-size group
    # NOTE(review): env_id is only upper-bounded here; a negative env_id
    # would pass this check and index app.state.envs from the end — confirm
    # callers never send negative ids.
    env_id = scored_data.env_id
    if env_id is not None and env_id < len(app.state.envs):
        expected_group_size = app.state.envs[env_id].get("group_size", 1)
        actual_group_size = len(scored_data.tokens)

        if actual_group_size != expected_group_size:
            # Mixed size group - add to buffer
            if env_id not in app.state.buffer:
                app.state.buffer[env_id] = []

            app.state.buffer[env_id].append(data_dict)

            # Try to find groups that sum to expected_group_size
            indices = find_groups_summing_to_target(
                app.state.buffer[env_id], expected_group_size
            )

            if indices:
                # Add these groups to queue in order
                # Pop from highest index down so earlier pops don't shift the
                # positions of indices not yet popped.
                groups_to_add = []
                for idx in sorted(indices, reverse=True):
                    groups_to_add.append(app.state.buffer[env_id].pop(idx))

                # Add in FIFO order
                # (reversed() undoes the high-to-low pop order above)
                for group in reversed(groups_to_add):
                    app.state.queue.append(group)
                    app.state.latest = group

            # Report total buffered sequence count (sum of group sizes),
            # not the number of buffered groups.
            return {
                "status": "buffered",
                "buffer_size": sum(
                    len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
                ),
            }

    # Normal path - correct size or no env info
    app.state.queue.append(data_dict)
    app.state.latest = data_dict
    return {"status": "received"}


@app.post("/scored_data_list")
Expand Down Expand Up @@ -564,6 +677,10 @@ async def get_status_env(env: EnvIdentifier):
@app.get("/reset_data")
async def reset_data():
try:
# Send RESET to sidecar
if hasattr(app.state, "zmq_port"):
send_to_sidecar({"_type": "reset"}, app.state.zmq_port)

del app.state.queue
app.state.group = None
app.state.project = None
Expand Down
132 changes: 132 additions & 0 deletions atroposlib/api/sidecar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import argparse
import logging
import threading
from typing import Any, Dict, Optional

import wandb
import zmq

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("ZMQSidecar")


class ZMQLogAggregator:
    """
    A sidecar service that listens for log data over ZeroMQ and aggregates it
    into the centralized WandB run.

    Protocol: peers PUSH pickled Python objects. Dicts containing a ``"_type"``
    key are control messages ("init" / "reset"); every other payload is
    forwarded to ``wandb.log`` while a run is active, and dropped otherwise.
    """

    def __init__(self, port: int = 5555, context: Optional[zmq.Context] = None):
        """
        Args:
            port: TCP port to bind the PULL socket on.
            context: Optional shared ZMQ context. If omitted, a private one is
                created and will be terminated by :meth:`stop`; a caller-owned
                context is left alive for the caller to manage.
        """
        self.port = port
        # Track ownership so stop() only terminates a context we created.
        self._owns_context = context is None
        self.context = context or zmq.Context()
        self.socket = self.context.socket(zmq.PULL)
        self.running = False
        # Reserved for a future threaded mode; unused in process mode today.
        self.thread = None

    def start(self):
        """Bind the socket and run the listening loop in the current thread.

        Blocks until :meth:`stop` is called or the process is interrupted.

        Raises:
            zmq.ZMQError: if the port cannot be bound.
        """
        if self.running:
            return

        try:
            self.socket.bind(f"tcp://*:{self.port}")
            logger.info(f"ZMQLogAggregator listening on port {self.port}")
        except zmq.ZMQError as e:
            logger.error(f"Failed to bind ZMQ socket on port {self.port}: {e}")
            raise

        self.running = True
        # In process mode, we run directly, not in a thread
        self._loop()

    def stop(self):
        """Stop the aggregator and release ZMQ resources."""
        self.running = False
        try:
            # linger=0 so close never blocks on undelivered messages.
            self.socket.close(linger=0)
        except Exception:
            pass
        # Fix: the context was previously never terminated, leaking it on
        # every stop. Only terminate contexts we created ourselves.
        if self._owns_context:
            try:
                self.context.term()
            except Exception:
                pass

    def _handle_control_message(self, payload: Dict[str, Any]):
        """Handle control messages for lifecycle management.

        Supported ``_type`` values:
            - "init": (re)start a WandB run from ``payload["config"]``,
              finishing any run already active.
            - "reset": finish the active WandB run, if any.
        """
        msg_type = payload.get("_type")

        if msg_type == "init":
            config = payload.get("config", {})
            logger.info(
                f"Received INIT command. Starting WandB run: {config.get('group', 'unknown')}"
            )

            # Make sure we finish any existing run
            if wandb.run is not None:
                logger.info("Finishing existing WandB run before starting new one")
                wandb.finish()

            try:
                wandb.init(**config)
                logger.info(f"WandB run initialized: {wandb.run.id}")
            except Exception as e:
                # Keep the sidecar alive even if wandb init fails; subsequent
                # log payloads are simply dropped until a successful init.
                logger.error(f"Failed to initialize WandB: {e}")

        elif msg_type == "reset":
            logger.info("Received RESET command. Finishing WandB run.")
            if wandb.run is not None:
                wandb.finish()
            else:
                logger.info("No active WandB run to finish.")

    def _loop(self):
        """Main listening loop.

        Polls with a 1s timeout so ``self.running`` is re-checked regularly,
        allowing :meth:`stop` (from a signal handler or another thread) to
        terminate the loop promptly.
        """
        poller = zmq.Poller()
        poller.register(self.socket, zmq.POLLIN)

        logger.info("ZMQ Sidecar loop started")

        while self.running:
            try:
                # check if open
                socks = dict(poller.poll(1000))
                if self.socket in socks:
                    # pyobj in case of some other data stuff later
                    payload = self.socket.recv_pyobj()

                    # Check if it's a control message
                    if isinstance(payload, dict) and "_type" in payload:
                        self._handle_control_message(payload)
                        continue

                    # Otherwise treat as log payload
                    if wandb.run is not None:
                        wandb.log(payload)
                    # else: drop the payload — no buffering, to avoid
                    # unbounded memory growth while wandb is inactive.

            except Exception as e:
                # Transient errors must not kill the loop; the running flag
                # alone governs exit.
                logger.error(f"Error in ZMQLogAggregator loop: {e}")


def main():
    """CLI entry point: parse the listen port and run the sidecar until
    interrupted (Ctrl-C stops it cleanly)."""
    parser = argparse.ArgumentParser(description="Atropos ZMQ Logging Sidecar")
    parser.add_argument("--port", type=int, default=5555, help="Port to listen on")
    cli_args = parser.parse_args()

    sidecar = ZMQLogAggregator(port=cli_args.port)
    try:
        sidecar.start()
    except KeyboardInterrupt:
        logger.info("Stopping ZMQ Sidecar...")
        sidecar.stop()


if __name__ == "__main__":
    main()
43 changes: 33 additions & 10 deletions atroposlib/cli/inference_node_wandb_watcher.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,66 @@
import argparse
import time
from urllib.parse import urlparse

import requests
import wandb


def update_wandb(health_statuses):
    # Thin wrapper: push the health-status metrics dict to the active
    # wandb run (assumes wandb.init was already called by the caller).
    wandb.log(health_statuses)
from atroposlib.utils.logging_client import ZMQLogger


def run(api_addr, tp, node_num):
print(f"Starting up with {api_addr}, {tp}, {node_num}", flush=True)
zmq_logger = None

while True:
try:
data = requests.get(f"{api_addr}/wandb_info").json()
wandb_group = data["group"]
wandb_project = data["project"]
wandb_group = data.get("group")
wandb_project = data.get("project")
zmq_port = data.get("zmq_port")
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
wandb_project = None
wandb_group = None
zmq_port = None
print("Waiting for init...")

if wandb_project is None:
time.sleep(1)
else:
if zmq_port:
try:
parsed = urlparse(api_addr)
host = parsed.hostname or "localhost"
zmq_addr = f"tcp://{host}:{zmq_port}"
zmq_logger = ZMQLogger(address=zmq_addr)
print(f"Connected to ZMQ Logger at {zmq_addr}")
break
except Exception as e:
print(f"Failed to connect ZMQ: {e}")
                # Fall back to the existing/old direct-wandb setup if the ZMQ sidecar isn't available

wandb.init(
project=wandb_project, group=wandb_group, name=f"inf_node_{node_num}"
)
break

curr_step = 0
health_statuses = {
f"server/server_health_{node_num}_{i}": 0.0 for i in range(8 // tp)
}
while True:
data = requests.get(f"{api_addr}/status").json()
step = data["current_step"]
if step > curr_step:
wandb.log(health_statuses, step=step)
curr_step = step
try:
data = requests.get(f"{api_addr}/status").json()
step = data["current_step"]
if step > curr_step:
if zmq_logger:
zmq_logger.log(health_statuses, step=step)
else:
wandb.log(health_statuses, step=step)
curr_step = step
except Exception as e:
print(f"Error fetching status: {e}")

time.sleep(60)
# Check on each server
for i in range(8 // tp):
Expand Down
Loading