Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b379dbc
wandb weave tracing integration
ropresearch Nov 5, 2025
6f169bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
e122194
precommit format fixes
ropresearch Nov 6, 2025
6de15ab
Merge branch 'rop/weave' of https://github.com/NousResearch/atropos i…
ropresearch Nov 6, 2025
5c34f48
documentation updates
ropresearch Nov 6, 2025
f5da18f
reverted weave ops, only for base env completions tracing
ropresearch Nov 10, 2025
62fa2ab
format and weave cleanup
ropresearch Nov 10, 2025
2df4ee9
Update to allow tracing flag through config
ropresearch Nov 13, 2025
0179b25
zmq message passing & env data aggregation for wandb
ropresearch Nov 18, 2025
646da3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
f58c927
decoupled zmq from server through run-api cli command subprocess
ropresearch Nov 19, 2025
03cc5e3
Merge branch 'rop/zmq-message-pass' of https://github.com/NousResearc…
ropresearch Nov 19, 2025
e2fe5e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
735c14f
Merge pull request #282 from NousResearch/rop/zmq-message-pass
ropresearch Nov 19, 2025
0050d1f
Merge branch 'main' into rop/weave
ropresearch Dec 4, 2025
d1bcadc
tag fixes and ZMQ change for better env categorization
ropresearch Dec 16, 2025
f28c344
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
d227d9a
env controlled wandb log pushes
ropresearch Dec 23, 2025
fba8922
Merge remote rop/weave, resolve sidecar conflicts
ropresearch Dec 23, 2025
7654e3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion atroposlib/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,6 @@ async def scored_data(scored_data: ScoredData):
"images": scored_data.images,
"env_id": scored_data.env_id,
}

# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
Expand Down
25 changes: 25 additions & 0 deletions atroposlib/envs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,28 @@ These class-level variables in `BaseEnv` can be overridden in your subclass to c
* **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.

By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the `Atropos` framework.

---

## Weave tracing in environments

Environments emit Weave traces to help you inspect rollout flow and LLM calls:

- Enabled by default; disable via config or env:
- Config: under your OpenAI server settings, set `tracing_enabled: false`
- YAML example:
```yaml
openai:
model_name: your-model
base_url: http://localhost:9000
tracing_enabled: false
```
- CLI example: `--openai--tracing_enabled false`
- Env (hard disable): `WEAVE_DISABLED=true`
- Optional project override via `WEAVE_PROJECT=<your-project-name>`.
- Traces include:
- Environment operations (group collection and send-to-API)
- LLM calls to OpenAI-compatible providers (including SGLang/TRL vLLM wrappers)
- Useful attributes (env name/id, model name, base URL, group/batch sizes)

View your traces at `https://weave.wandb.ai` under your project.
7 changes: 6 additions & 1 deletion atroposlib/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import aiohttp
import jsonlines
import numpy as np
import wandb
import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
Expand All @@ -27,6 +26,7 @@
from transformers import AutoTokenizer
from typing_extensions import TypedDict

import wandb
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
Expand Down Expand Up @@ -766,6 +766,7 @@ async def _send_scored_data_to_api(self, scored_data):
if isinstance(scored_data, list)
else f"{self.config.rollout_server_url}/scored_data"
)

async with aiohttp.ClientSession() as session:
async with self._post_json_with_compression(
session,
Expand Down Expand Up @@ -803,6 +804,10 @@ def _post_json_with_compression(

return session.post(url, data=body, headers=headers)

@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
)
async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
Expand Down
194 changes: 167 additions & 27 deletions atroposlib/envs/server_handling/server_baseline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import collections
import functools
import inspect
import os
import time
from abc import ABC, abstractmethod
from asyncio import exceptions
from contextlib import contextmanager
from typing import Literal, Optional

import numpy as np
Expand All @@ -11,6 +15,103 @@
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Weave / wandb are optional dependencies for tracing: import failures of any
# kind (missing package, broken install) degrade gracefully to "no tracing"
# rather than preventing the server module from loading.
try:
    import weave as _weave
except Exception:
    _weave = None
try:
    import wandb as _wandb
except Exception:
    _wandb = None

# Tracing is on only when weave imported successfully AND the user has not
# hard-disabled it via WEAVE_DISABLED=1/true/yes (case-insensitive).
WEAVE_ENABLED = _weave is not None and os.getenv(
    "WEAVE_DISABLED", "false"
).lower() not in ("1", "true", "yes")
# Process-wide guard so weave.init() runs at most once (see ensure_weave_init).
_weave_initialized = False


def _resolve_weave_project_name() -> str:
    """Pick the Weave project name for tracing.

    Precedence (first non-empty wins):
      1. ``WEAVE_PROJECT`` env var — the documented explicit override, so it
         must beat any wandb-derived value to actually act as an override.
      2. The active ``wandb.run``'s project, if a run exists.
      3. ``WANDB_PROJECT`` env var.
      4. The literal fallback ``"atropos"``.

    Returns:
        The resolved project name (never empty).
    """
    # Explicit weave-specific override takes priority over everything else.
    project_from_weave_env = os.getenv("WEAVE_PROJECT")
    if project_from_weave_env:
        return project_from_weave_env

    if _wandb is not None:
        try:
            run = getattr(_wandb, "run", None)
            project_from_run = (
                getattr(run, "project", None) if run is not None else None
            )
            if project_from_run:
                return str(project_from_run)
        except Exception:
            # wandb internals are best-effort; fall through to env/default.
            pass

    project_from_env = os.getenv("WANDB_PROJECT")
    if project_from_env:
        return project_from_env

    return "atropos"


def ensure_weave_init() -> None:
    """Lazily initialize the Weave client exactly once per process.

    No-op when tracing is unavailable or disabled (``WEAVE_ENABLED`` is
    False). The module-level ``_weave_initialized`` flag is set even when
    ``weave.init`` raises, so a broken tracing backend is attempted once and
    never retried on every traced call — tracing must not break rollouts.
    """
    global _weave_initialized
    if WEAVE_ENABLED and not _weave_initialized:
        project_name = _resolve_weave_project_name()
        try:
            _weave.init(project_name)
        except Exception:
            # Best-effort: swallow init failures; callers proceed untraced.
            pass
        _weave_initialized = True


def weave_op(func):
    """Wrap *func* so calls are traced as a Weave op when tracing is enabled.

    Returns *func* unchanged when Weave is globally disabled. Otherwise the
    wrapper checks a per-instance opt-out (``args[0].config.tracing_enabled``,
    when present) on every call: disabled instances invoke the raw function,
    enabled ones lazily init Weave and invoke the traced version. Works for
    both sync and async callables.
    """
    if not WEAVE_ENABLED:
        return func

    wrapped = _weave.op(func)

    def _tracing_enabled(args) -> bool:
        # Per-instance gate; any attribute-access failure defaults to "on"
        # so tracing config problems never change call behavior.
        try:
            if args and getattr(args[0], "config", None) is not None:
                return getattr(args[0].config, "tracing_enabled", True)
        except Exception:
            pass
        return True

    if inspect.iscoroutinefunction(func):

        @functools.wraps(wrapped)
        async def init_then_call(*args, **kwargs):
            if not _tracing_enabled(args):
                return await func(*args, **kwargs)
            ensure_weave_init()
            return await wrapped(*args, **kwargs)

    else:

        @functools.wraps(wrapped)
        def init_then_call(*args, **kwargs):
            if not _tracing_enabled(args):
                return func(*args, **kwargs)
            ensure_weave_init()
            return wrapped(*args, **kwargs)

    return init_then_call


@contextmanager
def weave_attributes(attrs: dict):
    """Attach *attrs* to Weave traces emitted inside the block.

    Degrades to a plain pass-through context when tracing is disabled.
    """
    if not WEAVE_ENABLED:
        yield
        return
    with _weave.attributes(attrs):
        yield


class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
def __init__(self, value: int):
Expand Down Expand Up @@ -111,6 +212,10 @@ class ServerBaseline(BaseModel):
server_type: Literal["openai", "trl", "sglang"] = Field(
default="openai", description="Type of server to use, openai or trl"
)
tracing_enabled: bool = Field(
default=True,
description="Enable Weave tracing for chat/completion ops (overridden by WEAVE_DISABLED).",
)


class APIServerConfig(ServerBaseline):
Expand Down Expand Up @@ -264,6 +369,7 @@ async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@weave_op
async def chat_completion(self, **kwargs) -> ChatCompletion:
"""
Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper.
Expand All @@ -285,15 +391,25 @@ async def chat_completion(self, **kwargs) -> ChatCompletion:
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._chat_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._chat_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
with weave_attributes(
{
"server_type": getattr(self.config, "server_type", None),
"endpoint": "chat_completion",
"model": self.config.model_name,
"base_url": getattr(self.config, "base_url", None),
"split": split,
"n": kwargs.get("n", 1),
}
):
if split == "train":
ret_data = await self._chat_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:

ret_data = await self._chat_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data

@retry(
Expand Down Expand Up @@ -330,6 +446,7 @@ async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
stat_dict["end"] = time.time()
return completions

@weave_op
async def completion(self, **kwargs) -> Completion:
"""
Completion handler, waits for the server to be healthy and then calls the completion wrapper.
Expand All @@ -352,20 +469,31 @@ async def completion(self, **kwargs) -> Completion:
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
with weave_attributes(
{
"server_type": getattr(self.config, "server_type", None),
"endpoint": "completion",
"model": self.config.model_name,
"base_url": getattr(self.config, "base_url", None),
"split": split,
"n": kwargs.get("n", 1),
}
):
if split == "train":
ret_data = await self._comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:

ret_data = await self._comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data

@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
@weave_op
async def _tokens_and_logprobs_comp(
self, stat_dict, **kwargs
) -> tuple[list, list, list, list]:
Expand Down Expand Up @@ -426,13 +554,25 @@ async def tokens_and_logprobs_completion(
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._tokens_and_logprobs_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
# Give separate eval workers, if desired, gotta go fast for those evals
ret_data = await self._tokens_and_logprobs_comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
with weave_attributes(
{
"server_type": getattr(self.config, "server_type", None),
"endpoint": "tokens_and_logprobs_completion",
"model": self.config.model_name,
"base_url": getattr(self.config, "base_url", None),
"split": split,
"n": kwargs.get("n", 1),
}
):
if split == "train":
ret_data = await self._tokens_and_logprobs_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:

ret_data = await self._tokens_and_logprobs_comp_eval(
stat_dict, **kwargs
)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"markdown",
"numpy",
"wandb",
"weave",
"gymnasium",
"math-verify==0.7.0",
"jinja2",
Expand Down