Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b379dbc
wandb weave tracing integration
ropresearch Nov 5, 2025
6f169bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
e122194
precommit format fixes
ropresearch Nov 6, 2025
6de15ab
Merge branch 'rop/weave' of https://github.com/NousResearch/atropos i…
ropresearch Nov 6, 2025
5c34f48
documentation updates
ropresearch Nov 6, 2025
f5da18f
reverted weave ops, only for base env completions tracing
ropresearch Nov 10, 2025
62fa2ab
format and weave cleanup
ropresearch Nov 10, 2025
2df4ee9
Update to allow tracing flag through config
ropresearch Nov 13, 2025
0179b25
zmq message passing & env data aggregation for wandb
ropresearch Nov 18, 2025
646da3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
f58c927
decoupled zmq from server through run-api cli command subprocess
ropresearch Nov 19, 2025
03cc5e3
Merge branch 'rop/zmq-message-pass' of https://github.com/NousResearc…
ropresearch Nov 19, 2025
e2fe5e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
735c14f
Merge pull request #282 from NousResearch/rop/zmq-message-pass
ropresearch Nov 19, 2025
0050d1f
Merge branch 'main' into rop/weave
ropresearch Dec 4, 2025
d1bcadc
tag fixes and ZMQ change for better env categorization
ropresearch Dec 16, 2025
f28c344
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
d227d9a
env controlled wandb log pushes
ropresearch Dec 23, 2025
fba8922
Merge remote rop/weave, resolve sidecar conflicts
ropresearch Dec 23, 2025
7654e3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions atroposlib/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ uvicorn atroposlib.api.server:app --host 0.0.0.0 --port 8000 --reload

The API documentation (Swagger UI) will be available at `http://<your-server-ip>:8000/docs`.

### Weave tracing for API submissions

This service emits lightweight Weave traces for rollout submissions so you can inspect submission behavior while training:

- Enabled by default. Disable with `WEAVE_DISABLED=true`.
- Optional project name via `WEAVE_PROJECT` (defaults to `atropos-api`).
- `/scored_data` shows whether data was buffered or enqueued and the queue lengths before/after.

View traces in your project at `https://weave.wandb.ai`.

## API Endpoints

### General
Expand Down
60 changes: 58 additions & 2 deletions atroposlib/api/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import gzip
import os
import time
import uuid
from typing import Any, Dict, List, Optional

import weave
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
Expand Down Expand Up @@ -37,6 +39,33 @@

app.add_middleware(GZipMiddleware, minimum_size=1000)

if os.getenv("WEAVE_DISABLED", "false").lower() not in ("1", "true", "yes"):
try:
weave.init(os.getenv("WEAVE_PROJECT", "atropos-api"))
except Exception:
pass


@weave.op
async def _trace_api_enq(
    endpoint: str,
    env_id: Optional[int],
    group_size: int,
    buffered: bool,
    buffer_size: int,
    queue_len_before: int,
    queue_len_after: int,
) -> Dict[str, Any]:
    """Emit a lightweight Weave trace for a rollout submission.

    Simply echoes the submission attributes back as the op's return value so
    Weave records both the inputs and the output in the trace.

    :param endpoint: API route that received the submission (e.g. "/scored_data").
    :param env_id: Registered environment id, or None if the payload carried none.
    :param buffered: True if the group was held in the mixed-size buffer rather
        than enqueued directly.
    :param buffer_size: Total sequences currently buffered for this env
        (0 on the direct-enqueue path).
    :param queue_len_before: Queue length sampled before handling the payload.
    :param queue_len_after: Queue length after handling the payload.
    """
    # Declared async because call sites `await` this op; a plain sync def
    # would make `await` raise TypeError on the returned dict, which the
    # callers' try/except would swallow silently.
    return {
        "endpoint": endpoint,
        "env_id": env_id,
        "group_size": group_size,
        "buffered": buffered,
        "buffer_size": buffer_size,
        "queue_len_before": queue_len_before,
        "queue_len_after": queue_len_after,
    }


class GZipRequestMiddleware:

Expand Down Expand Up @@ -361,7 +390,9 @@ async def scored_data(scored_data: ScoredData):
"images": scored_data.images,
"env_id": scored_data.env_id,
}

queue_before = (
len(getattr(app.state, "queue", [])) if hasattr(app.state, "queue") else 0
)
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
Expand Down Expand Up @@ -391,16 +422,41 @@ async def scored_data(scored_data: ScoredData):
app.state.queue.append(group)
app.state.latest = group

return {
ret = {
"status": "buffered",
"buffer_size": sum(
len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
),
}
try:
await _trace_api_enq(
endpoint="/scored_data",
env_id=env_id,
group_size=actual_group_size,
buffered=True,
buffer_size=ret["buffer_size"],
queue_len_before=queue_before,
queue_len_after=len(app.state.queue),
)
except Exception:
pass
return ret

# Normal path - correct size or no env info
app.state.queue.append(data_dict)
app.state.latest = data_dict
try:
await _trace_api_enq(
endpoint="/scored_data",
env_id=env_id,
group_size=len(scored_data.tokens),
buffered=False,
buffer_size=0,
queue_len_before=queue_before,
queue_len_after=len(app.state.queue),
)
except Exception:
pass
return {"status": "received"}


Expand Down
15 changes: 15 additions & 0 deletions atroposlib/envs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,18 @@ These class-level variables in `BaseEnv` can be overridden in your subclass to c
* **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.

By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the `Atropos` framework.

---

## Weave tracing in environments

Environments emit Weave traces to help you inspect rollout flow and LLM calls:

- Enabled by default; disable with `WEAVE_DISABLED=true`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to put this into the config dict?

- Optional project override via `WEAVE_PROJECT=<your-project-name>`.
- Traces include:
- Environment operations (group collection and send-to-API)
- LLM calls to OpenAI-compatible providers (including SGLang/TRL vLLM wrappers)
- Useful attributes (env name/id, model name, base URL, group/batch sizes)

View your traces at `https://weave.wandb.ai` under your project.
99 changes: 83 additions & 16 deletions atroposlib/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jsonlines
import numpy as np
import wandb
import weave
import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
Expand Down Expand Up @@ -287,6 +288,7 @@ async def collect_trajectory(
"Handle env single method must be implemented in subclass "
)

@weave.op
async def collect_trajectories(self, item: Item) -> Tuple[
Union[
Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]
Expand Down Expand Up @@ -750,6 +752,7 @@ async def evaluate_log(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
)
@weave.op
async def _send_scored_data_to_api(self, scored_data):
"""
Send scored data to the API with retry logic for timeouts and server errors.
Expand All @@ -766,19 +769,33 @@ async def _send_scored_data_to_api(self, scored_data):
if isinstance(scored_data, list)
else f"{self.config.rollout_server_url}/scored_data"
)
async with aiohttp.ClientSession() as session:
async with self._post_json_with_compression(
session,
url,
scored_data,
) as resp:
if resp.status >= 500:
logging.debug(f"Server error: {resp.status}, retrying...")
raise Exception(f"Server error: {resp.status}")
elif resp.status >= 400:
logging.error(f"Client error: {resp.status}, not retrying")
return
print(await resp.text())
payload_len = (
sum(len(g.get("tokens", [])) for g in scored_data)
if isinstance(scored_data, list)
else len(scored_data.get("tokens", []))
)
with weave.attributes(
{
"endpoint": url,
"payload_groups": (
len(scored_data) if isinstance(scored_data, list) else 1
),
"payload_sequences": payload_len,
}
):
async with aiohttp.ClientSession() as session:
async with self._post_json_with_compression(
session,
url,
scored_data,
) as resp:
if resp.status >= 500:
logging.debug(f"Server error: {resp.status}, retrying...")
raise Exception(f"Server error: {resp.status}")
elif resp.status >= 400:
logging.error(f"Client error: {resp.status}, not retrying")
return
print(await resp.text())

def _post_json_with_compression(
self,
Expand All @@ -803,6 +820,7 @@ def _post_json_with_compression(

return session.post(url, data=body, headers=headers)

@weave.op
async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
Expand Down Expand Up @@ -898,6 +916,7 @@ async def handle_send_to_api(
)
print(f"Failed to send {data_type_str} after retries: {e}")

@weave.op
async def handle_env(
self, item_uuid: str
) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]:
Expand All @@ -909,7 +928,16 @@ async def handle_env(
print(f"item {item_uuid} not found... returning")
return None
start_time = time.time()
logger.debug(f"handle_env: Starting with item: {item}")
with weave.attributes(
{
"env_id": getattr(self, "env_id", None),
"env_name": getattr(self, "wandb_prepend", None)
or self.name
or self.__class__.__name__,
"step": self.curr_step,
}
):
logger.debug(f"handle_env: Starting with item: {item}")
# do a rollout with item
try:
to_postprocess, to_backlog = await self.collect_trajectories(item)
Expand Down Expand Up @@ -1129,6 +1157,15 @@ async def env_manager(self):
"""
await self.setup()
await self.setup_wandb()
# Initialize Weave tracing once per process (if not disabled)
if os.getenv("WEAVE_DISABLED", "false").lower() not in ("1", "true", "yes"):
project_name = os.getenv("WEAVE_PROJECT") or (
self.wandb_project or "atropos"
)
try:
weave.init(project_name)
except Exception:
pass
await self.register_env()
await self.get_server_info()
# Wait for other instances to get setup :)
Expand All @@ -1140,8 +1177,38 @@ async def env_manager(self):
)
# get status from server
self.last_loop_time = time.time()
await self.get_status()
await self.env_step_checks()
with weave.attributes(
{
"env_id": getattr(self, "env_id", None),
"env_name": getattr(self, "wandb_prepend", None)
or self.name
or self.__class__.__name__,
"wandb_project": self.wandb_project,
"wandb_group": self.wandb_group,
"model_name": (
getattr(
getattr(self.server.servers[0], "config", object()),
"model_name",
None,
)
if self.server.servers
else None
),
"base_url": (
getattr(
getattr(self.server.servers[0], "config", object()),
"base_url",
None,
)
if self.server.servers
else None
),
"group_size": self.config.group_size,
"batch_size": self.config.batch_size,
}
):
await self.get_status()
await self.env_step_checks()
logger.info(f"env_manager: Status dict: {self.status_dict}")
if (
self.status_dict["current_step"]
Expand Down
Loading